OpenDAS / text-generation-inference / Commits / 017a2a8c

Unverified commit 017a2a8c, authored Jan 31, 2023 by OlivierDehaene, committed via GitHub on Jan 31, 2023

feat: Add token streaming using ServerSideEvents support (#41)
parent 54fec931

Showing 20 changed files with 1014 additions and 483 deletions (+1014 -483)
Cargo.lock                                    +2    -0
launcher/Cargo.toml                           +1    -1
launcher/tests/bloom_560m.json                +4    -2
launcher/tests/mt0_base.json                  +4    -2
proto/generate.proto                          +41   -25
router/Cargo.toml                             +2    -0
router/client/src/client.rs                   +14   -14
router/client/src/lib.rs                      +2    -1
router/client/src/sharded_client.rs           +19   -19
router/src/db.rs                              +14   -36
router/src/infer.rs                           +353  -0
router/src/lib.rs                             +17   -4
router/src/server.rs                          +207  -100
router/src/validation.rs                      +49   -24
server/tests/models/test_bloom.py             +61   -44
server/tests/models/test_causal_lm.py         +53   -47
server/tests/models/test_santacoder.py        +14   -16
server/tests/models/test_seq2seq_lm.py        +49   -44
server/text_generation/models/causal_lm.py    +57   -53
server/text_generation/models/seq2seq_lm.py   +51   -51
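This commit adds a /generate_stream route that emits one Server-Sent Event per generated token (see router/src/server.rs below). As a rough illustration of how a client might consume that stream, here is a minimal Rust sketch that is not part of this commit; the local address, the request body, and the reqwest ("json" + "stream" features), tokio, futures, and serde_json dependencies are all assumptions.

// Naive SSE consumer sketch: posts a request and prints every `data:` payload.
use futures::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    // Assumed local router address and request shape.
    let response = client
        .post("http://127.0.0.1:3000/generate_stream")
        .json(&serde_json::json!({
            "inputs": "Testing token streaming",
            "parameters": { "max_new_tokens": 20 }
        }))
        .send()
        .await?;

    // Each SSE frame arrives as `data: <json>` followed by a blank line; this
    // simple parser splits the byte stream on newlines and prints the payloads.
    let mut stream = response.bytes_stream();
    let mut buffer = String::new();
    while let Some(chunk) = stream.next().await {
        buffer.push_str(&String::from_utf8_lossy(&chunk?));
        while let Some(pos) = buffer.find('\n') {
            let line = buffer[..pos].trim().to_string();
            buffer.drain(..=pos);
            if let Some(payload) = line.strip_prefix("data:") {
                println!("token event: {}", payload.trim());
            }
        }
    }
    Ok(())
}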
Cargo.lock (view file @ 017a2a8c)

@@ -1829,6 +1829,7 @@ dependencies = [
 name = "text-generation-router"
 version = "0.1.0"
 dependencies = [
+ "async-stream",
  "axum",
  "clap 4.0.22",
  "futures",
@@ -1841,6 +1842,7 @@ dependencies = [
  "thiserror",
  "tokenizers",
  "tokio",
+ "tokio-stream",
  "tracing",
  "tracing-subscriber",
 ]
launcher/Cargo.toml (view file @ 017a2a8c)

@@ -16,4 +16,4 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
 [dev-dependencies]
 float_eq = "1.0.1"
 reqwest = { version = "0.11.13", features = ["blocking", "json"] }
-serde = "1.0.150"
+serde = { version = "1.0.150", features = ["derive"] }
launcher/tests/bloom_560m.json (view file @ 017a2a8c)

@@ -3,7 +3,7 @@
   "details": {
     "finish_reason": "length",
     "generated_tokens": 20,
-    "tokens": [
+    "prefill": [
       [
         10264,
         "Test",
@@ -13,7 +13,9 @@
       [
         8821,
         " request",
        -11.895094
-      ],
+      ]
+    ],
+    "tokens": [
       [
         17,
         ".",
launcher/tests/mt0_base.json (view file @ 017a2a8c)

@@ -3,12 +3,14 @@
   "details": {
     "finish_reason": "length",
     "generated_tokens": 20,
-    "tokens": [
+    "prefill": [
       [
         0,
         "<pad>",
         null
-      ],
+      ]
+    ],
+    "tokens": [
       [
         259,
         "",
proto/generate.proto (view file @ 017a2a8c)

@@ -7,10 +7,10 @@ service TextGenerationService {
     rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
     /// Empties batch cache
     rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
-    /// Generate tokens for a batch
-    rpc Generate (GenerateRequest) returns (GenerateResponse);
-    /// Generate tokens for a list of cached batches
-    rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse);
+    /// Prefill batch and decode first token
+    rpc Prefill (PrefillRequest) returns (PrefillResponse);
+    /// Decode token for a list of prefilled batches
+    rpc Decode (DecodeRequest) returns (DecodeResponse);
 }

 /// Empty request
@@ -70,44 +70,60 @@ message Batch {
 }

 message GeneratedText {
-    /// Request
-    Request request = 1;
     /// Output
-    string output_text = 2;
+    string text = 1;
     /// Number of generated tokens
-    uint32 generated_tokens = 3;
-    /// Tokens
-    repeated string tokens = 4;
-    /// Token IDs
-    repeated uint32 token_ids = 5;
-    /// Logprobs
-    repeated float logprobs = 6;
+    uint32 generated_tokens = 2;
     /// Finish reason
-    string finish_reason = 7;
+    string finish_reason = 3;
     /// Seed
-    optional uint64 seed = 8;
+    optional uint64 seed = 4;
 }

-message GenerateRequest {
+message PrefillTokens {
+    /// Prefill Token IDs
+    repeated uint32 ids = 1;
+    /// Prefill Logprobs
+    repeated float logprobs = 2;
+    /// Prefill tokens
+    repeated string texts = 3;
+}
+
+message Generation {
+    /// Request ID
+    uint64 request_id = 1;
+    /// Prefill tokens (optional)
+    PrefillTokens prefill_tokens = 2;
+    /// Token ID
+    uint32 token_id = 3;
+    /// Logprob
+    float token_logprob = 4;
+    /// Text
+    string token_text = 5;
+    /// Complete generated text
+    GeneratedText generated_text = 6;
+}
+
+message PrefillRequest {
     /// Batch
     Batch batch = 1;
 }

-message GenerateResponse {
-    /// Finished requests
-    repeated GeneratedText generated_texts = 1;
+message PrefillResponse {
+    /// Generation
+    repeated Generation generations = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
 }

-message GenerateWithCacheRequest {
+message DecodeRequest {
     /// Cached batches
     repeated Batch batches = 1;
 }

-message GenerateWithCacheResponse {
-    /// Finished requests
-    repeated GeneratedText generated_texts = 1;
+message DecodeResponse {
+    /// Decodes
+    repeated Generation generations = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
 }
\ No newline at end of file
router/Cargo.toml (view file @ 017a2a8c)

@@ -13,6 +13,7 @@ name = "text-generation-router"
 path = "src/main.rs"

 [dependencies]
+async-stream = "0.3.3"
 axum = { version = "0.5.16", features = ["json", "serde_json"] }
 text-generation-client = { path = "client" }
 clap = { version = "4.0.15", features = ["derive", "env"] }
@@ -25,6 +26,7 @@ serde_json = "1.0.85"
 thiserror = "1.0.37"
 tokenizers = "0.13.0"
 tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.11"
 tracing = "0.1.36"
 tracing-subscriber = { version = "0.3.15", features = ["json"] }
router/client/src/client.rs (view file @ 017a2a8c)

@@ -70,36 +70,36 @@ impl Client {
     /// Generate one token for each request in the given batch
     ///
-    /// Returns a list of generated texts of request that met their stopping criteria
+    /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip(self))]
-    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
+    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) });
         let response = self
             .stub
-            .generate(request)
-            .instrument(info_span!("generate"))
+            .prefill(request)
+            .instrument(info_span!("prefill"))
             .await?
             .into_inner();
-        Ok((response.generated_texts, response.batch))
+        Ok((response.generations, response.batch))
     }

-    /// Generate one token for each request in the given cached batch
+    /// Generate one token for each request in the given cached batches
     ///
-    /// Returns a list of generated texts of request that met their stopping criteria
+    /// Returns Generation for each request in batches
     /// and the next cached batch
     #[instrument(skip(self))]
-    pub async fn generate_with_cache(
+    pub async fn decode(
         &mut self,
         batches: Vec<Batch>,
-    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        let request = tonic::Request::new(GenerateWithCacheRequest { batches });
+    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        let request = tonic::Request::new(DecodeRequest { batches });
         let response = self
             .stub
-            .generate_with_cache(request)
-            .instrument(info_span!("generate_with_cache"))
+            .decode(request)
+            .instrument(info_span!("decode"))
             .await?
             .into_inner();
-        Ok((response.generated_texts, response.batch))
+        Ok((response.generations, response.batch))
     }
 }
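The renamed client methods split generation into one prefill call on a fresh batch followed by repeated decode calls on the cached batch. The standalone sketch below only illustrates that calling pattern; the types are simplified stand-ins, not the generated gRPC client above.

// Stand-in types for illustration only; the real ones come from the generated
// protobuf code and the tonic client in this file.
#[derive(Clone)]
struct Batch { id: u64 }
struct Generation { request_id: u64, token_text: String }

struct FakeClient;

impl FakeClient {
    // Mimics Client::prefill: run the prompt, return the first token per request
    // plus the cached batch to keep decoding.
    fn prefill(&mut self, batch: Batch) -> (Vec<Generation>, Option<Batch>) {
        (vec![Generation { request_id: batch.id, token_text: "first".into() }], Some(batch))
    }

    // Mimics Client::decode: generate one more token for every cached batch.
    fn decode(&mut self, batches: Vec<Batch>) -> (Vec<Generation>, Option<Batch>) {
        let next = batches.into_iter().next();
        (vec![Generation { request_id: 0, token_text: "next".into() }], next)
    }
}

fn main() {
    let mut client = FakeClient;
    // Prefill once, then keep decoding while a cached batch is returned.
    let (mut generations, mut cached) = client.prefill(Batch { id: 0 });
    let mut steps = 0;
    while let Some(batch) = cached {
        if steps == 3 { break; } // arbitrary stop condition for the sketch
        let (more, next) = client.decode(vec![batch]);
        generations.extend(more);
        cached = next;
        steps += 1;
    }
    for g in generations {
        println!("request {} -> {}", g.request_id, g.token_text);
    }
}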
router/client/src/lib.rs (view file @ 017a2a8c)

@@ -7,7 +7,8 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+    Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
+    StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
router/client/src/sharded_client.rs (view file @ 017a2a8c)

 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, GeneratedText};
+use crate::{Batch, Client, Generation};
 use futures::future::join_all;
 use futures::future::select_all;
 use tonic::transport::Uri;
@@ -37,46 +37,46 @@ impl ShardedClient {
         Self::from_master_client(master_client).await
     }

     /// Clear the past generations cache
     pub async fn clear_cache(&mut self) -> Result<()> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| client.clear_cache())
             .collect();
         join_all(futures).await.into_iter().collect()
     }

     /// Generate one token for each request in the given batch
     ///
-    /// Returns a list of generated texts of request that met their stopping criteria
+    /// Returns Generation for each request in batch
     /// and the next cached batch
-    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.generate(batch.clone())))
+            .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
         // As soon as we receive one response, we can return as all shards will return the same
         let (result, _, _) = select_all(futures).await;
         result
     }

-    /// Generate one token for each request in the given cached batch
+    /// Generate one token for each request in the given cached batches
     ///
-    /// Returns a list of generated texts of request that met their stopping criteria
+    /// Returns Generation for each request in batches
     /// and the next cached batch
-    pub async fn generate_with_cache(
+    pub async fn decode(
         &mut self,
         batches: Vec<Batch>,
-    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    ) -> Result<(Vec<Generation>, Option<Batch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.generate_with_cache(batches.clone())))
+            .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
         // As soon as we receive one response, we can return as all shards will return the same
         let (result, _, _) = select_all(futures).await;
         result
     }
 }
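ShardedClient fans the same request out to every shard and, as the comment notes, returns as soon as the first shard answers because all shards produce identical results. The following self-contained sketch of that select_all pattern is for illustration only and assumes just the futures and tokio crates.

use futures::future::select_all;
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Three "shards" that all compute the same answer at different speeds.
    let shards = vec![50u64, 10, 30];
    let futures: Vec<_> = shards
        .into_iter()
        .map(|delay_ms| {
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                format!("result after {delay_ms}ms")
            })
        })
        .collect();

    // select_all resolves with the first completed future; the remaining
    // futures are returned and simply dropped here, mirroring ShardedClient.
    let (result, _index, _remaining) = select_all(futures).await;
    println!("first shard answered: {result}");
}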
router/src/db.rs (view file @ 017a2a8c)

 /// This code is massively inspired by Tokio mini-redis
-use crate::InferResponse;
-use crate::{GenerateParameters, GenerateRequest};
+use crate::infer::InferError;
+use crate::infer::InferStreamResponse;
+use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
-use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
-};
-use tokio::sync::oneshot::Sender;
+use text_generation_client::{Batch, Request};
+use tokio::sync::mpsc::UnboundedSender;
+use tokio::sync::OwnedSemaphorePermit;
 use tokio::time::Instant;

 /// Database entry
 #[derive(Debug)]
 pub(crate) struct Entry {
     /// Request
-    pub request: GenerateRequest,
-    /// Response sender to communicate between the Batcher and the batching_task
-    pub response_tx: Sender<Result<InferResponse, ClientError>>,
-    /// Number of tokens in the input
-    pub input_length: usize,
+    pub request: ValidGenerateRequest,
+    /// Response sender to communicate between the Infer struct and the batching_task
+    pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
     /// Instant when this entry was created
     pub time: Instant,
     /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
+    /// Permit
+    pub _permit: OwnedSemaphorePermit,
 }

 /// Request Database
@@ -71,9 +71,9 @@ impl State {
             requests.push(Request {
                 id: *id,
                 inputs: entry.request.inputs.clone(),
-                input_length: entry.input_length as u32,
-                parameters: Some((&entry.request.parameters).into()),
-                stopping_parameters: Some(entry.request.parameters.clone().into()),
+                input_length: entry.request.input_length,
+                parameters: Some(entry.request.parameters.clone()),
+                stopping_parameters: Some(entry.request.stopping_parameters.clone()),
             });
             ids.push(*id);
@@ -158,25 +158,3 @@ impl Db {
         None
     }
 }
-
-impl From<&GenerateParameters> for NextTokenChooserParameters {
-    fn from(parameters: &GenerateParameters) -> Self {
-        Self {
-            temperature: parameters.temperature,
-            top_k: parameters.top_k as u32,
-            top_p: parameters.top_p,
-            do_sample: parameters.do_sample,
-            // FIXME: remove unwrap
-            seed: parameters.seed.unwrap(),
-        }
-    }
-}
-
-impl From<GenerateParameters> for StoppingCriteriaParameters {
-    fn from(parameters: GenerateParameters) -> Self {
-        Self {
-            stop_sequences: parameters.stop,
-            max_new_tokens: parameters.max_new_tokens,
-        }
-    }
-}
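Each Entry now carries an OwnedSemaphorePermit, so the concurrency slot acquired when the request enters the system is released automatically when the entry is dropped. A small sketch of that tokio semaphore pattern, independent of this codebase and with a hypothetical Entry stand-in:

use std::sync::Arc;
use tokio::sync::Semaphore;

struct Entry {
    request: String,
    // Held for the lifetime of the entry; dropping the entry frees the slot.
    _permit: tokio::sync::OwnedSemaphorePermit,
}

#[tokio::main]
async fn main() {
    // At most 2 requests in flight at once.
    let limit = Arc::new(Semaphore::new(2));

    let e1 = Entry { request: "a".into(), _permit: limit.clone().try_acquire_owned().unwrap() };
    let e2 = Entry { request: "b".into(), _permit: limit.clone().try_acquire_owned().unwrap() };

    // A third acquisition fails immediately: the model is "overloaded".
    assert!(limit.clone().try_acquire_owned().is_err());

    drop(e1); // finishing a request releases its permit
    assert!(limit.clone().try_acquire_owned().is_ok());
    println!("second request still pending: {}", e2.request);
}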
router/src/batcher.rs → router/src/infer.rs (view file @ 017a2a8c)

 /// Batching and inference logic
-use crate::{Db, Entry};
-use crate::{ErrorResponse, GenerateRequest};
-use axum::http::StatusCode;
-use axum::Json;
+use crate::validation::{Validation, ValidationError};
+use crate::GenerateRequest;
+use crate::{Db, Entry, Token};
 use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
+use text_generation_client::{
+    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+};
 use thiserror::Error;
-use tokio::sync::{oneshot, Notify};
+use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tokio_stream::StreamExt;
 use tracing::instrument;

-/// Batcher
+/// Inference struct
 #[derive(Clone)]
-pub struct Batcher {
+pub struct Infer {
+    /// Validation
+    validation: Validation,
     /// Request database
     db: Db,
     /// Shared state
     shared: Arc<Shared>,
+    /// Inference limit
+    limit_concurrent_requests: Arc<Semaphore>,
 }

-/// Batcher shared state
+/// Infer shared state
 struct Shared {
     /// Batching background Tokio task notifier
     batching_task: Notify,
 }

-impl Batcher {
+impl Infer {
     pub(crate) fn new(
         client: ShardedClient,
+        validation: Validation,
         max_batch_size: usize,
         max_waiting_tokens: usize,
+        max_concurrent_requests: usize,
     ) -> Self {
-        // Batcher shared state
+        // Infer shared state
         let db = Db::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
@@ -48,37 +57,111 @@ impl Batcher {
             shared.clone(),
         ));

-        Self { db, shared }
+        // Inference limit with a semaphore
+        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
+
+        Self {
+            validation,
+            db,
+            shared,
+            limit_concurrent_requests: semaphore,
+        }
     }

-    /// Add a new request to the database and return a future that will generate the text
-    pub(crate) async fn infer(
+    /// Add a new request to the database and return a stream of InferStreamResponse
+    pub(crate) async fn generate_stream(
         &self,
-        input_length: usize,
         request: GenerateRequest,
-    ) -> Result<InferResponse, InferError> {
-        // One shot channel to communicate with the background batching task
-        let (response_tx, response_rx) = oneshot::channel();
+    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+        // Limit concurrent requests by acquiring a permit from the semaphore
+        // This permit will live as long as Entry
+        let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
+
+        // Validate request
+        let valid_request = self.validation.validate(request).await?;

-        // Try to append the request to the database
+        // MPSC channel to communicate with the background batching task
+        let (response_tx, response_rx) = mpsc::unbounded_channel();
+
+        // Append the request to the database
         self.db.append(Entry {
-            request,
+            request: valid_request,
             response_tx,
-            input_length,
             time: Instant::now(),
             batch_time: None,
+            _permit: permit,
         });

         // Notify the background task that we have a new entry in the database that needs
         // to be batched
         self.shared.batching_task.notify_one();

-        // Await on the response from the background task
-        // We can safely unwrap as the background task will never drop the sender
-        response_rx
-            .await
-            .unwrap()
-            .map_err(|err| InferError::GenerationError(err.to_string()))
+        // Return stream
+        Ok(UnboundedReceiverStream::new(response_rx))
+    }
+
+    /// Add a new request to the database and return a InferResponse
+    pub(crate) async fn generate(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<InferResponse, InferError> {
+        // Create stream
+        let mut stream = self.generate_stream(request).await?;
+
+        // Return values
+        let mut result_prefill = Vec::new();
+        let mut result_tokens = Vec::new();
+        let mut result_generated_text = None;
+        let mut result_start = None;
+        let mut result_queued = None;
+
+        // Iterate on stream
+        while let Some(response) = stream.next().await {
+            match response? {
+                // Add prefill tokens
+                InferStreamResponse::Prefill(tokens) => {
+                    // Create Token objects
+                    // We do that here instead of in the Python code as Rust for loops are faster
+                    result_prefill = tokens
+                        .ids
+                        .into_iter()
+                        .zip(tokens.logprobs.into_iter())
+                        .zip(tokens.texts.into_iter())
+                        .map(|((id, logprob), text)| Token(id, text, logprob))
+                        .collect();
+                }
+                // Push last token
+                InferStreamResponse::Token(token) => result_tokens.push(token),
+                // Final message
+                // Set return values
+                InferStreamResponse::End {
+                    token,
+                    generated_text,
+                    start,
+                    queued,
+                } => {
+                    result_tokens.push(token);
+                    result_generated_text = Some(generated_text);
+                    result_start = Some(start);
+                    result_queued = Some(queued)
+                }
+            }
+        }
+
+        // Check that we received a `InferStreamResponse::End` message
+        if let (Some(generated_text), Some(queued), Some(start)) =
+            (result_generated_text, result_queued, result_start)
+        {
+            Ok(InferResponse {
+                prefill: result_prefill,
+                tokens: result_tokens,
+                generated_text,
+                queued,
+                start,
+            })
+        } else {
+            Err(InferError::IncompleteGeneration)
+        }
     }
 }
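Infer::generate above is simply the streaming path collapsed into one response: it drains the UnboundedReceiverStream and folds the Prefill/Token/End messages into a single InferResponse. A stripped-down, self-contained sketch of that collect-a-stream pattern (tokio and tokio-stream only, with a hypothetical Msg enum):

use tokio::sync::mpsc;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[derive(Debug)]
enum Msg {
    Token(String),
    End { full_text: String },
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();

    // Producer side: roughly what the batching task does per generated token.
    tx.send(Msg::Token("Hello".into())).unwrap();
    tx.send(Msg::Token(" world".into())).unwrap();
    tx.send(Msg::End { full_text: "Hello world".into() }).unwrap();
    drop(tx); // closing the channel ends the stream

    // Consumer side: fold the stream into a single aggregated response.
    let mut stream = UnboundedReceiverStream::new(rx);
    let mut tokens = Vec::new();
    let mut final_text = None;
    while let Some(msg) = stream.next().await {
        match msg {
            Msg::Token(t) => tokens.push(t),
            Msg::End { full_text } => final_text = Some(full_text),
        }
    }
    println!("{} tokens, final: {:?}", tokens.len(), final_text);
}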
@@ -99,14 +182,14 @@ async fn batching_task(
     // Infinite loop
     loop {
-        // Wait for a notification from the Batcher struct
+        // Wait for a notification from the Infer struct
         shared.batching_task.notified().await;

         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
         while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
-            let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await;
+            let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
             let mut waiting_tokens = 1;

             // We loop until we do not receive any cached batch from the inference server (== until
@@ -132,7 +215,7 @@ async fn batching_task(
                     {
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
-                            wrap_future(client.generate(new_batch), &mut new_entries).await;
+                            wrap_future(client.prefill(new_batch), &mut new_entries).await;
                         // Reset waiting counter
                         waiting_tokens = 1;
                         // Extend current batch with the new batch
@@ -143,21 +226,21 @@ async fn batching_task(
                     }
                 }

-                cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await;
+                cached_batch = wrap_future(client.decode(batches), &mut entries).await;
                 waiting_tokens += 1;
             }
         }
     }
 }

-/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
+/// Wrap a future inside a match statement to handle errors and send the responses to Infer
 async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
+    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
-        Ok((generated_texts, next_batch)) => {
-            send_generated(generated_texts, entries);
+        Ok((generations, next_batch)) => {
+            send_generations(generations, entries);
             next_batch
         }
         // If we have an error, we discard the whole batch
@@ -168,69 +251,103 @@ async fn wrap_future(
     }
 }

-/// Send errors to the Batcher for all `entries`
+/// Send errors to Infer for all `entries`
 fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
     entries.drain().for_each(|(_, entry)| {
         // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry.response_tx.send(Err(error.clone())).unwrap_or(());
+        entry
+            .response_tx
+            .send(Err(InferError::GenerationError(error.to_string())))
+            .unwrap_or(());
     });
 }

-/// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
-    finished.into_iter().for_each(|output| {
+/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
+fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+    generations.into_iter().for_each(|generation| {
         // Get entry
         // We can `expect` here as the request id should always be in the entries
         let entry = entries
-            .remove(&output.request.unwrap().id)
+            .get(&generation.request_id)
             .expect("ID not found in entries. This is a bug.");

-        let response = InferResponse {
-            output_text: output.output_text,
-            generated_tokens: output.generated_tokens,
-            token_ids: output.token_ids,
-            tokens: output.tokens,
-            logprobs: output.logprobs,
-            finish_reason: output.finish_reason,
-            seed: output.seed,
-            queued: entry.time,
-            start: entry.batch_time.unwrap(), // unwrap is always valid
-            end: Instant::now(),
-        };
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry.response_tx.send(Ok(response)).unwrap_or(());
+        if let Some(prefill_tokens) = generation.prefill_tokens {
+            // Send message
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            entry
+                .response_tx
+                .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
+                .unwrap_or(());
+        }
+
+        // Create last Token
+        let token = Token(
+            generation.token_id,
+            generation.token_text,
+            generation.token_logprob,
+        );
+
+        if let Some(generated_text) = generation.generated_text {
+            // Remove entry as this is the last message
+            // We can `expect` here as the request id should always be in the entries
+            let entry = entries
+                .remove(&generation.request_id)
+                .expect("ID not found in entries. This is a bug.");
+
+            // Send message
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            entry
+                .response_tx
+                .send(Ok(InferStreamResponse::End {
+                    token,
+                    generated_text,
+                    queued: entry.time,
+                    start: entry.batch_time.unwrap(),
+                }))
+                .unwrap_or(());
+        } else {
+            // Send message
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            entry
+                .response_tx
+                .send(Ok(InferStreamResponse::Token(token)))
+                .unwrap_or(());
+        }
     });
 }

+#[derive(Debug)]
+pub(crate) enum InferStreamResponse {
+    // Optional first message
+    Prefill(PrefillTokens),
+    // Intermediate messages
+    Token(Token),
+    // Last message
+    End {
+        token: Token,
+        generated_text: GeneratedText,
+        start: Instant,
+        queued: Instant,
+    },
+}
+
 #[derive(Debug)]
 pub(crate) struct InferResponse {
-    pub(crate) output_text: String,
-    pub(crate) generated_tokens: u32,
-    pub(crate) token_ids: Vec<u32>,
-    pub(crate) tokens: Vec<String>,
-    pub(crate) logprobs: Vec<f32>,
-    pub(crate) finish_reason: String,
-    pub(crate) seed: Option<u64>,
+    pub(crate) prefill: Vec<Token>,
+    pub(crate) tokens: Vec<Token>,
+    pub(crate) generated_text: GeneratedText,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
-    pub(crate) end: Instant,
 }

 #[derive(Debug, Error)]
 pub enum InferError {
     #[error("Request failed during generation: {0}")]
     GenerationError(String),
-}
-
-/// Convert to Axum supported format
-impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
-    fn from(err: InferError) -> Self {
-        match err {
-            InferError::GenerationError(_) => (
-                StatusCode::FAILED_DEPENDENCY,
-                Json(ErrorResponse {
-                    error: err.to_string(),
-                }),
-            ),
-        }
-    }
-}
+    #[error("Model is overloaded")]
+    Overloaded(#[from] TryAcquireError),
+    #[error("Input validation error: {0}")]
+    ValidationError(#[from] ValidationError),
+    #[error("Incomplete generation")]
+    IncompleteGeneration,
+}
router/src/lib.rs (view file @ 017a2a8c)

 /// Text Generation Inference Webserver
-mod batcher;
 mod db;
+mod infer;
 pub mod server;
 mod validation;

-use batcher::{Batcher, InferResponse};
 use db::{Db, Entry};
+use infer::Infer;
 use serde::{Deserialize, Serialize};
 use validation::Validation;
@@ -69,21 +69,34 @@ pub(crate) struct GenerateRequest {
     pub parameters: GenerateParameters,
 }

+#[derive(Debug, Serialize)]
+pub struct Token(u32, String, f32);
+
 #[derive(Serialize)]
 pub(crate) struct Details {
     pub finish_reason: String,
     pub generated_tokens: u32,
     pub seed: Option<u64>,
-    pub tokens: Vec<(u32, String, f32)>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prefill: Option<Vec<Token>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tokens: Option<Vec<Token>>,
 }

 #[derive(Serialize)]
-pub(crate) struct GeneratedText {
+pub(crate) struct GenerateResponse {
     pub generated_text: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub details: Option<Details>,
 }

+#[derive(Serialize)]
+pub(crate) struct StreamResponse {
+    pub token: Token,
+    pub generated_text: Option<String>,
+    pub details: Option<Details>,
+}
+
 #[derive(Serialize)]
 pub(crate) struct ErrorResponse {
     pub error: String,
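Token is a tuple struct, so serde serializes it as a JSON array of [id, text, logprob]. The sketch below shows the resulting JSON shapes; the structs are simplified copies made only so the example compiles on its own, and it assumes serde with the derive feature plus serde_json.

use serde::Serialize;

// Simplified copies of the router types above, for illustration only.
#[derive(Serialize)]
struct Token(u32, String, f32);

#[derive(Serialize)]
struct StreamResponse {
    token: Token,
    generated_text: Option<String>,
}

fn main() {
    let intermediate = StreamResponse {
        token: Token(17, ".".to_string(), -1.23),
        generated_text: None,
    };
    let last = StreamResponse {
        token: Token(259, "!".to_string(), -0.5),
        generated_text: Some("Test.".to_string()),
    };

    // The tuple struct becomes a plain JSON array: [id, text, logprob].
    println!("{}", serde_json::to_string(&intermediate).unwrap());
    // {"token":[17,".",-1.23],"generated_text":null}
    println!("{}", serde_json::to_string(&last).unwrap());
    // {"token":[259,"!",-0.5],"generated_text":"Test."}
}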
router/src/server.rs (view file @ 017a2a8c)

 /// HTTP Server logic
+use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
+    Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer,
+    StreamResponse, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
+use axum::response::sse::{Event, KeepAlive, Sse};
 use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
+use futures::Stream;
+use std::convert::Infallible;
 use std::net::SocketAddr;
-use std::sync::Arc;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::Semaphore;
 use tokio::time::Instant;
+use tokio_stream::StreamExt;
 use tracing::instrument;
-// Server shared state
-#[derive(Clone)]
-struct ServerState {
-    validation: Validation,
-    batcher: Batcher,
-    limit_concurrent_requests: Arc<Semaphore>,
-}
-
 /// Health check method
-#[instrument(skip(state), fields(time, time_per_token))]
-async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+#[instrument(skip(infer))]
+async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
     // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
     // be a bit too slow for a health check.
     // What we should do instead if check if the gRPC channels are still healthy.

-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
-        (
-            StatusCode::TOO_MANY_REQUESTS,
-            Json(ErrorResponse {
-                error: "Model is overloaded".to_string(),
-            }),
-        )
-    })?;
-
     // Send a small inference request
-    state
-        .batcher
-        .infer(
-            1,
-            GenerateRequest {
-                inputs: "liveness".to_string(),
-                parameters: GenerateParameters {
-                    temperature: 1.0,
-                    top_k: 0,
-                    top_p: 1.0,
-                    do_sample: false,
-                    max_new_tokens: 1,
-                    stop: vec![],
-                    details: false,
-                    seed: None,
-                },
-            },
-        )
+    infer
+        .generate(GenerateRequest {
+            inputs: "liveness".to_string(),
+            parameters: GenerateParameters {
+                temperature: 1.0,
+                top_k: 0,
+                top_p: 1.0,
+                do_sample: false,
+                max_new_tokens: 1,
+                stop: vec![],
+                details: false,
+                seed: None,
+            },
+        })
         .await?;
     Ok(())
 }

 /// Generate method
 #[instrument(
-    skip(state),
+    skip(infer),
     fields(
         total_time,
         validation_time,
...
)
)]
async
fn
generate
(
state
:
Extension
<
ServerState
>
,
infer
:
Extension
<
Infer
>
,
req
:
Json
<
GenerateRequest
>
,
)
->
Result
<
impl
IntoResponse
,
(
StatusCode
,
Json
<
ErrorResponse
>
)
>
{
let
span
=
tracing
::
Span
::
current
();
let
start_time
=
Instant
::
now
();
// Limit concurrent requests by acquiring a permit from the semaphore
let
_
permit
=
state
.limit_concurrent_requests
.try_acquire
()
.map_err
(|
_
|
{
tracing
::
error!
(
"Model is overloaded"
);
(
StatusCode
::
TOO_MANY_REQUESTS
,
Json
(
ErrorResponse
{
error
:
"Model is overloaded"
.to_string
(),
}),
)
})
?
;
// Validate request
let
details
=
req
.0
.parameters.details
;
let
(
input_length
,
validated_request
)
=
state
.validation
.validate
(
req
.0
)
.await
.map_err
(|
err
|
{
tracing
::
error!
(
"{}"
,
err
.to_string
());
err
})
?
;
// Inference
let
response
=
state
.batcher
.infer
(
input_length
,
validated_request
)
.await
.map_err
(|
err
|
{
tracing
::
error!
(
"{}"
,
err
.to_string
());
err
})
?
;
let
details
=
req
.0
.parameters.details
;
let
response
=
infer
.generate
(
req
.0
)
.await
.map_err
(|
err
|
{
tracing
::
error!
(
"{}"
,
err
.to_string
());
err
})
?
;
// Token details
let
details
=
match
details
{
true
=>
{
let
tokens
=
response
.token_ids
.into_iter
()
.zip
(
response
.tokens
.into_iter
())
.zip
(
response
.logprobs
.into_iter
())
.map
(|((
id
,
text
),
logprob
)|
(
id
,
text
,
logprob
))
.collect
();
Some
(
Details
{
seed
:
response
.seed
,
finish_reason
:
response
.finish_reason
,
generated_tokens
:
response
.generated_tokens
,
tokens
,
})
}
true
=>
Some
(
Details
{
finish_reason
:
response
.generated_text.finish_reason
,
generated_tokens
:
response
.generated_text.generated_tokens
,
prefill
:
Some
(
response
.prefill
),
tokens
:
Some
(
response
.tokens
),
seed
:
response
.generated_text.seed
,
}),
false
=>
None
,
};
...
...
@@ -133,8 +88,8 @@ async fn generate(
     let total_time = start_time.elapsed();
     let validation_time = response.queued - start_time;
     let queue_time = response.start - response.queued;
-    let inference_time = response.end - response.start;
-    let time_per_token = inference_time / response.generated_tokens;
+    let inference_time = Instant::now() - response.start;
+    let time_per_token = inference_time / response.generated_text.generated_tokens;

     // Headers
     let mut headers = HeaderMap::new();
@@ -160,22 +115,143 @@ async fn generate(
     );

     // Tracing metadata
-    tracing::Span::current().record("total_time", format!("{:?}", total_time));
-    tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
-    tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
-    tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
-    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
-    tracing::Span::current().record("seed", format!("{:?}", response.seed));
-    tracing::info!("Output: {}", response.output_text);
+    span.record("total_time", format!("{:?}", total_time));
+    span.record("validation_time", format!("{:?}", validation_time));
+    span.record("queue_time", format!("{:?}", queue_time));
+    span.record("inference_time", format!("{:?}", inference_time));
+    span.record("time_per_token", format!("{:?}", time_per_token));
+    span.record("seed", format!("{:?}", response.generated_text.seed));
+    tracing::info!("Output: {}", response.generated_text.text);

     // Send response
-    let response = vec![GeneratedText {
-        generated_text: response.output_text,
+    let response = vec![GenerateResponse {
+        generated_text: response.generated_text.text,
         details,
     }];
     Ok((headers, Json(response)))
 }

+/// Generate stream method
+#[instrument(
+    skip(infer),
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token
+    )
+)]
+async fn generate_stream(
+    infer: Extension<Infer>,
+    req: Json<GenerateRequest>,
+) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
+    let span = tracing::Span::current();
+    let start_time = Instant::now();
+
+    let stream = async_stream::stream! {
+        // Inference
+        let mut end_reached = false;
+        let mut error = false;
+        let details = req.0.parameters.details;
+
+        match infer.generate_stream(req.0).await {
+            Ok(mut response_stream) => {
+                // Server Side Event stream
+                while let Some(response) = response_stream.next().await {
+                    match response {
+                        Ok(response) => {
+                            match response {
+                                // Prefill is ignored
+                                InferStreamResponse::Prefill(_) => {}
+                                // Yield event for every new token
+                                InferStreamResponse::Token(token) => {
+                                    // StreamResponse
+                                    let stream_token = StreamResponse {
+                                        token,
+                                        generated_text: None,
+                                        details: None,
+                                    };
+
+                                    yield Ok(Event::default().json_data(stream_token).unwrap())
+                                }
+                                // Yield event for last token and compute timings
+                                InferStreamResponse::End {
+                                    token,
+                                    generated_text,
+                                    start,
+                                    queued,
+                                } => {
+                                    // Token details
+                                    let details = match details {
+                                        true => Some(Details {
+                                            finish_reason: generated_text.finish_reason,
+                                            generated_tokens: generated_text.generated_tokens,
+                                            prefill: None,
+                                            tokens: None,
+                                            seed: generated_text.seed,
+                                        }),
+                                        false => None,
+                                    };
+
+                                    // Timings
+                                    let total_time = start_time.elapsed();
+                                    let validation_time = queued - start_time;
+                                    let queue_time = start - queued;
+                                    let inference_time = Instant::now() - start;
+                                    let time_per_token = inference_time / generated_text.generated_tokens;
+
+                                    // Tracing metadata
+                                    span.record("total_time", format!("{:?}", total_time));
+                                    span.record("validation_time", format!("{:?}", validation_time));
+                                    span.record("queue_time", format!("{:?}", queue_time));
+                                    span.record("inference_time", format!("{:?}", inference_time));
+                                    span.record("time_per_token", format!("{:?}", time_per_token));
+                                    tracing::info!(parent: &span, "Output: {}", generated_text.text);
+
+                                    // StreamResponse
+                                    end_reached = true;
+                                    let stream_token = StreamResponse {
+                                        token,
+                                        generated_text: Some(generated_text.text),
+                                        details
+                                    };
+
+                                    yield Ok(Event::default().json_data(stream_token).unwrap())
+                                }
+                            }
+                        }
+                        // Trace and yield error
+                        Err(err) => {
+                            error = true;
+                            tracing::error!("{}", err.to_string());
+                            yield Ok(Event::from(err))
+                        }
+                    }
+                }
+            },
+            // Trace and yield error
+            Err(err) => {
+                error = true;
+                tracing::error!("{}", err.to_string());
+                yield Ok(Event::from(err))
+            }
+        }
+        // Check if generation reached the end
+        // Skip if we already sent an error
+        if !end_reached && !error {
+            let err = InferError::IncompleteGeneration;
+            tracing::error!("{}", err.to_string());
+            yield Ok(Event::from(err))
+        }
+    };
+
+    Sse::new(stream).keep_alive(KeepAlive::default())
+}
+
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
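generate_stream above wires an async_stream block into axum's Sse response type. As a minimal, self-contained illustration of those same axum 0.5 SSE building blocks, here is a sketch; the route name, port, and event payloads are assumptions, not taken from this commit.

use axum::response::sse::{Event, KeepAlive, Sse};
use axum::routing::get;
use axum::Router;
use futures::Stream;
use std::convert::Infallible;
use std::time::Duration;

// Emits three demo events, one per second, then ends the stream.
async fn demo_stream() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let stream = async_stream::stream! {
        for i in 0..3u32 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            yield Ok(Event::default().data(format!("token {i}")));
        }
    };
    Sse::new(stream).keep_alive(KeepAlive::default())
}

#[tokio::main]
async fn main() {
    let app = Router::new().route("/demo_stream", get(demo_stream));
    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}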
@@ -189,21 +265,23 @@ pub async fn run(
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
-    let shared_state = ServerState {
+    let infer = Infer::new(
+        client,
         validation,
-        batcher,
-        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
-    };
+        max_batch_size,
+        max_waiting_tokens,
+        max_concurrent_requests,
+    );

     // Create router
     let app = Router::new()
         .route("/", post(generate))
         .route("/generate", post(generate))
+        .route("/generate_stream", post(generate_stream))
         .route("/", get(health))
         .route("/health", get(health))
-        .layer(Extension(shared_state.clone()));
+        .layer(Extension(infer));

     // Run server
     axum::Server::bind(&addr)
@@ -240,3 +318,32 @@ async fn shutdown_signal() {
     tracing::info!("signal received, starting graceful shutdown");
 }
+
+/// Convert to Axum supported formats
+impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
+    fn from(err: InferError) -> Self {
+        let status_code = match err {
+            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
+            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
+            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
+        };
+
+        (
+            status_code,
+            Json(ErrorResponse {
+                error: err.to_string(),
+            }),
+        )
+    }
+}
+
+impl From<InferError> for Event {
+    fn from(err: InferError) -> Self {
+        Event::default()
+            .json_data(ErrorResponse {
+                error: err.to_string(),
+            })
+            .unwrap()
+    }
+}
router/src/validation.rs (view file @ 017a2a8c)

 /// Payload validation logic
-use crate::{ErrorResponse, GenerateRequest};
-use axum::http::StatusCode;
-use axum::Json;
+use crate::{GenerateParameters, GenerateRequest};
 use rand::rngs::ThreadRng;
 use rand::Rng;
+use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::{mpsc, oneshot};
@@ -40,7 +39,7 @@ impl Validation {
     pub(crate) async fn validate(
         &self,
         request: GenerateRequest,
-    ) -> Result<(usize, GenerateRequest), ValidationError> {
+    ) -> Result<ValidGenerateRequest, ValidationError> {
         // Create response channel
         let (sender, receiver) = oneshot::channel();
         // Send request to the background validation task
@@ -106,11 +105,11 @@ fn validation_worker(
 }

 fn validate(
-    mut request: GenerateRequest,
+    request: GenerateRequest,
     tokenizer: &Tokenizer,
     max_input_length: usize,
     rng: &mut ThreadRng,
-) -> Result<(usize, GenerateRequest), ValidationError> {
+) -> Result<ValidGenerateRequest, ValidationError> {
     if request.parameters.temperature <= 0.0 {
         return Err(ValidationError::Temperature);
     }
@@ -131,19 +130,48 @@ fn validate(
     }

     // If seed is None, assign a random one
-    if request.parameters.seed.is_none() {
-        request.parameters.seed = Some(rng.gen());
-    }
+    let seed = match request.parameters.seed {
+        None => rng.gen(),
+        Some(seed) => seed,
+    };

     // Get the number of tokens in the input
     match tokenizer.encode(request.inputs.clone(), true) {
-        Ok(inputs) => {
-            let input_length = inputs.len();
+        Ok(encoding) => {
+            let input_length = encoding.len();
             if input_length > max_input_length {
                 Err(ValidationError::InputLength(input_length, max_input_length))
             } else {
-                Ok((input_length, request))
+                // Return ValidGenerateRequest
+                let GenerateParameters {
+                    temperature,
+                    top_k,
+                    top_p,
+                    do_sample,
+                    max_new_tokens,
+                    stop: stop_sequences,
+                    ..
+                } = request.parameters;
+
+                let parameters = NextTokenChooserParameters {
+                    temperature,
+                    top_k: top_k as u32,
+                    top_p,
+                    do_sample,
+                    seed,
+                };
+                let stopping_parameters = StoppingCriteriaParameters {
+                    max_new_tokens,
+                    stop_sequences,
+                };
+
+                Ok(ValidGenerateRequest {
+                    inputs: request.inputs,
+                    input_length: input_length as u32,
+                    parameters,
+                    stopping_parameters,
+                })
             }
         }
         Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
@@ -152,9 +180,17 @@ fn validate(
 type ValidationRequest = (
     GenerateRequest,
-    oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
+    oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
 );

+#[derive(Debug)]
+pub(crate) struct ValidGenerateRequest {
+    pub inputs: String,
+    pub input_length: u32,
+    pub parameters: NextTokenChooserParameters,
+    pub stopping_parameters: StoppingCriteriaParameters,
+}
+
 #[derive(Error, Debug)]
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
@@ -172,14 +208,3 @@ pub enum ValidationError {
     #[error("tokenizer error {0}")]
     Tokenizer(String),
 }
-
-impl From<ValidationError> for (StatusCode, Json<ErrorResponse>) {
-    fn from(err: ValidationError) -> Self {
-        (
-            StatusCode::UNPROCESSABLE_ENTITY,
-            Json(ErrorResponse {
-                error: err.to_string(),
-            }),
-        )
-    }
-}
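validate() now returns a fully resolved ValidGenerateRequest: the user-facing GenerateParameters are split into sampling and stopping parameters, and a missing seed is replaced by a random one so every shard samples identically. A tiny sketch of that mapping with stand-in structs (the real ones are the gRPC message types imported above); field names and types here are illustrative only.

use rand::Rng;

// Stand-ins for the gRPC parameter messages, for illustration only.
#[derive(Debug)]
struct NextTokenChooserParameters { temperature: f32, top_k: u32, top_p: f32, do_sample: bool, seed: u64 }
#[derive(Debug)]
struct StoppingCriteriaParameters { max_new_tokens: u32, stop_sequences: Vec<String> }

// User-facing parameters, loosely mirroring GenerateParameters.
struct GenerateParameters { temperature: f32, top_k: i32, top_p: f32, do_sample: bool, max_new_tokens: u32, stop: Vec<String>, seed: Option<u64> }

fn split(p: GenerateParameters, rng: &mut impl Rng) -> (NextTokenChooserParameters, StoppingCriteriaParameters) {
    // A missing seed is resolved here, once, before the request is queued.
    let seed = p.seed.unwrap_or_else(|| rng.gen());
    (
        NextTokenChooserParameters { temperature: p.temperature, top_k: p.top_k as u32, top_p: p.top_p, do_sample: p.do_sample, seed },
        StoppingCriteriaParameters { max_new_tokens: p.max_new_tokens, stop_sequences: p.stop },
    )
}

fn main() {
    let mut rng = rand::thread_rng();
    let (sampling, stopping) = split(
        GenerateParameters { temperature: 0.7, top_k: 50, top_p: 0.9, do_sample: true, max_new_tokens: 20, stop: vec!["\n".to_string()], seed: None },
        &mut rng,
    );
    println!("{sampling:?}\n{stopping:?}");
}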
server/tests/models/test_bloom.py (view file @ 017a2a8c)

@@ -91,9 +91,9 @@ def test_causal_lm_batch_type(default_bloom):
 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     sequence_length = len(default_bloom_batch.all_input_ids[0])
-    generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
+    generations, next_batch = default_bloom.generate_token(default_bloom_batch)

-    assert generated_texts == []
+    assert len(generations) == len(default_bloom_batch)
     assert isinstance(next_batch, CausalLMBatch)
     assert not next_batch.keys_head_dim_last
@@ -122,24 +122,30 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert all(
         [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
     )

+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 10264 for generation in generations])
+    assert all([generation.token_text == "Test" for generation in generations])
+    assert generations[0].request_id == 0
+

 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
     next_batch = default_bloom_batch

     for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(default_bloom_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
-    )
-    assert generated_texts[0].request == default_bloom_batch.requests[0]
-    assert (
-        generated_texts[0].generated_tokens
-        == default_bloom_batch.stopping_criterias[0].max_new_tokens
-    )
+    assert len(generations) == 1
+    assert (
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
+    assert generations[0].request_id == default_bloom_batch.requests[0].id
+    assert (
+        generations[0].generated_text.generated_tokens
+        == default_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
@@ -152,17 +158,19 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(default_multi_requests_bloom_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
-    assert (
-        generated_texts[0].generated_tokens
-        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
-    )
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert (
+        generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
+    )
+    assert (
+        generations[1].generated_text.generated_tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+    )
@@ -171,19 +179,22 @@ def test_causal_lm_generate_token_completion_multi(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 1
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
-    )
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
-    assert (
-        generated_texts[0].generated_tokens
-        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
-    )
+    assert len(generations) == 1
+    assert (
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
+    assert (
+        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
+    )
+    assert (
+        generations[0].generated_text.generated_tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
@@ -243,17 +254,19 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
-    assert (
-        generated_texts[0].generated_tokens
-        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
-    )
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert (
+        generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
+    )
+    assert (
+        generations[2].generated_text.generated_tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+    )
@@ -262,19 +275,20 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
-    )
-    assert generated_texts[0].request == default_bloom_batch.requests[0]
-    assert (
-        generated_texts[0].generated_tokens
-        == default_bloom_batch.stopping_criterias[0].max_new_tokens
-    )
+    assert len(generations) == 2
+    assert (
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
+    assert generations[0].request_id == default_bloom_batch.requests[0].id
+    assert (
+        generations[0].generated_text.generated_tokens
+        == default_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
@@ -284,18 +298,21 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generated_texts, next_batch = default_bloom.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_bloom.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch = default_bloom.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
-    )
-    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
-    assert (
-        generated_texts[0].generated_tokens
-        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
-    )
+    assert len(generations) == 1
+    assert (
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
+    assert (
+        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
+    )
+    assert (
+        generations[0].generated_text.generated_tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
server/tests/models/test_causal_lm.py (view file @ 017a2a8c)

@@ -88,11 +88,9 @@ def test_causal_lm_batch_type(default_causal_lm):
 def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     sequence_length = len(default_causal_lm_batch.all_input_ids[0])
-    generated_texts, next_batch = default_causal_lm.generate_token(
-        default_causal_lm_batch
-    )
+    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)

-    assert generated_texts == []
+    assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)

     assert len(next_batch.all_input_ids) == next_batch.size
@@ -121,6 +119,11 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert all(
         [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
     )

+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 13 for generation in generations])
+    assert all([generation.token_text == "." for generation in generations])
+    assert generations[0].request_id == 0
+

 def test_causal_lm_generate_token_completion(
@@ -128,18 +131,17 @@ def test_causal_lm_generate_token_completion(
 ):
     next_batch = default_causal_lm_batch

     for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
-    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
-    assert (
-        generated_texts[0].generated_tokens
-        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
-    )
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].request_id == default_causal_lm_batch.requests[0].id
+    assert (
+        generations[0].generated_text.generated_tokens
+        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    )
@@ -152,19 +154,20 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784)"
-    assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
-    )
-    assert (
-        generated_texts[0].generated_tokens
-        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
-    )
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "Test.java:784)"
+    assert (
+        generations[1].request_id == default_multi_requests_causal_lm_batch.requests[1].id
+    )
+    assert (
+        generations[1].generated_text.generated_tokens
+        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+    )
@@ -173,19 +176,20 @@ def test_causal_lm_generate_token_completion_multi(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 1
     ):
-        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
-    assert (
-        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
-    )
-    assert (
-        generated_texts[0].generated_tokens
-        == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
-    )
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert (
+        generations[0].request_id == default_multi_requests_causal_lm_batch.requests[0].id
+    )
+    assert (
+        generations[0].generated_text.generated_tokens
+        == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    )
@@ -244,19 +248,20 @@ def test_batch_concatenate(
for
_
in
range
(
default_multi_requests_causal_lm_batch
.
stopping_criterias
[
1
].
max_new_tokens
-
2
):
generat
ed_text
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
assert
generat
ed_texts
==
[]
generat
ion
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
assert
len
(
generat
ions
)
==
len
(
next_batch
)
generat
ed_text
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
generat
ion
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
assert
next_batch
is
not
None
assert
len
(
generat
ed_text
s
)
==
1
assert
generat
ed_texts
[
0
].
output_
text
==
"Test.java:784)"
assert
len
(
generat
ion
s
)
==
3
assert
generat
ions
[
2
].
generated_text
.
text
==
"Test.java:784)"
assert
(
generated_texts
[
0
].
request
==
default_multi_requests_causal_lm_batch
.
requests
[
1
]
generations
[
2
].
request_id
==
default_multi_requests_causal_lm_batch
.
requests
[
1
].
id
)
assert
(
generated_text
s
[
0
]
.
generated_tokens
generations
[
2
].
generated_text
.
generated_tokens
==
default_multi_requests_causal_lm_batch
.
stopping_criterias
[
1
].
max_new_tokens
)
...
...
@@ -265,17 +270,17 @@ def test_batch_concatenate(
-
default_multi_requests_causal_lm_batch
.
stopping_criterias
[
1
].
max_new_tokens
-
2
):
generat
ed_text
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
assert
generat
ed_texts
==
[]
generat
ion
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
assert
len
(
generat
ions
)
==
len
(
next_batch
)
generat
ed_text
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
generat
ion
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
assert
next_batch
is
not
None
assert
len
(
generat
ed_text
s
)
==
1
assert
generat
ed_texts
[
0
].
output_
text
==
"Test.java:784) at net.minecraft."
assert
generat
ed_text
s
[
0
].
request
==
default_causal_lm_batch
.
requests
[
0
]
assert
len
(
generat
ion
s
)
==
2
assert
generat
ions
[
0
].
generated_text
.
text
==
"Test.java:784) at net.minecraft."
assert
generat
ion
s
[
0
].
request
_id
==
default_causal_lm_batch
.
requests
[
0
]
.
id
assert
(
generated_text
s
[
0
]
.
generated_tokens
generations
[
0
].
generated_text
.
generated_tokens
==
default_causal_lm_batch
.
stopping_criterias
[
0
].
max_new_tokens
)
...
...
@@ -285,18 +290,19 @@ def test_batch_concatenate(
-
default_multi_requests_causal_lm_batch
.
stopping_criterias
[
1
].
max_new_tokens
-
4
):
generat
ed_text
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
assert
generat
ed_texts
==
[]
generat
ion
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
assert
len
(
generat
ions
)
==
len
(
next_batch
)
generat
ed_text
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
generat
ion
s
,
next_batch
=
default_causal_lm
.
generate_token
(
next_batch
)
assert
next_batch
is
None
assert
len
(
generat
ed_text
s
)
==
1
assert
generat
ed_texts
[
0
].
output_
text
==
"Test.java:784) at net.minecraft."
assert
len
(
generat
ion
s
)
==
1
assert
generat
ions
[
0
].
generated_text
.
text
==
"Test.java:784) at net.minecraft."
assert
(
generated_texts
[
0
].
request
==
default_multi_requests_causal_lm_batch
.
requests
[
0
]
generations
[
0
].
request_id
==
default_multi_requests_causal_lm_batch
.
requests
[
0
].
id
)
assert
(
generated_text
s
[
0
]
.
generated_tokens
generations
[
0
].
generated_text
.
generated_tokens
==
default_multi_requests_causal_lm_batch
.
stopping_criterias
[
0
].
max_new_tokens
)
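The hunks above switch these tests from a list of finished GeneratedText objects to per-step Generation objects: one Generation per request still in the batch, with the final text now under generation.generated_text.text. A minimal sketch of the consumption pattern the assertions imply; the drain helper below is hypothetical, not part of the repository:

    # Hypothetical helper, not part of the test suite: drains a batch the same way
    # these tests do, keeping only the requests that have finished.
    def drain(model, batch):
        finished = {}
        while batch is not None:
            generations, batch = model.generate_token(batch)
            for generation in generations:
                # generated_text stays None until the request hits its stopping criteria
                if generation.generated_text is not None:
                    finished[generation.request_id] = generation.generated_text.text
        return finished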
server/tests/models/test_santacoder.py  View file @ 017a2a8c
...
@@ -50,18 +50,17 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
    next_batch = batch

    for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_santacoder.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch = default_santacoder.generate_token(next_batch)
    assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "def test_get_all_users_with_"
-    assert generated_texts[0].request == batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].request_id == batch.requests[0].id
    assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
        == batch.stopping_criterias[0].max_new_tokens
    )
...
@@ -76,20 +75,19 @@ def test_fim_santacoder_generate_token_completion(
    next_batch = batch

    for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_santacoder.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch = default_santacoder.generate_token(next_batch)
    assert next_batch is None

-    assert len(generated_texts) == 1
+    assert len(generations) == 1
    assert (
-        generated_texts[0].output_text
+        generations[0].generated_text.text
        == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
    )
-    assert generated_texts[0].request == batch.requests[0]
-    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert generations[0].request_id == batch.requests[0].id
    assert (
-        generated_texts[0].generated_tokens
+        generations[0].generated_text.generated_tokens
        == batch.stopping_criterias[0].max_new_tokens
    )
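Every intermediate step in these tests now asserts len(generations) == len(next_batch), which leans on the __len__ added to the batch classes in this commit. A rough sketch of that invariant as a reusable check; step_and_check is a hypothetical helper, not repository code:

    # Hypothetical helper mirroring the invariant asserted after each decoding step:
    # exactly one Generation per request still present in the next batch.
    def step_and_check(model, batch):
        generations, next_batch = model.generate_token(batch)
        if next_batch is not None:
            # CausalLMBatch / Seq2SeqLMBatch now implement __len__ over their requests
            assert len(generations) == len(next_batch)
        return generations, next_batch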
server/tests/models/test_seq2seq_lm.py  View file @ 017a2a8c
...
@@ -99,11 +99,11 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
    sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-    generated_texts, next_batch = default_seq2seq_lm.generate_token(
+    generations, next_batch = default_seq2seq_lm.generate_token(
        default_seq2seq_lm_batch
    )

-    assert generated_texts == []
+    assert len(generations) == len(next_batch)
    assert isinstance(next_batch, Seq2SeqLMBatch)

    assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
...
@@ -145,6 +145,11 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
            for p in next_batch.past_key_values
        ]
    )

+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
+    assert all([generation.token_id.item() == 259 for generation in generations])
+    assert all([generation.token_text == "" for generation in generations])
+    assert generations[0].request_id == 0

def test_seq2seq_lm_generate_token_completion(
...
@@ -152,16 +157,16 @@ def test_seq2seq_lm_generate_token_completion(
):
    next_batch = default_seq2seq_lm_batch
    for _ in range(6):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
-    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].generated_tokens == 7
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
+    assert generations[0].generated_text.generated_tokens == 7

def test_seq2seq_lm_generate_token_completion_multi(
...
@@ -170,33 +175,33 @@ def test_seq2seq_lm_generate_token_completion_multi(
    next_batch = default_multi_requests_seq2seq_lm_batch

    for i in range(4):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few "
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == "a few "
    assert (
-        generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[1]
+        generations[1].request_id == default_multi_requests_seq2seq_lm_batch.requests[1].id
    )
-    assert generated_texts[0].generated_tokens == 5
+    assert generations[1].generated_text.generated_tokens == 5

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-    assert generated_texts == []
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
    assert (
-        generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[0]
+        generations[0].request_id == default_multi_requests_seq2seq_lm_batch.requests[0].id
    )
-    assert generated_texts[0].generated_tokens == 7
+    assert generations[0].generated_text.generated_tokens == 7

def test_batch_concatenate(
...
@@ -291,35 +296,35 @@ def test_batch_concatenate(
    )

    for _ in range(3):
-        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
-        assert generated_texts == []
+        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert len(generations) == len(next_batch)

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few "
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == "a few "
    assert (
-        generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[1]
+        generations[2].request_id == default_multi_requests_seq2seq_lm_batch.requests[1].id
    )
-    assert generated_texts[0].generated_tokens == 5
+    assert generations[2].generated_text.generated_tokens == 5

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
-    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].generated_tokens == 7
+    assert len(generations) == 2
+    assert generations[0].generated_text.text == "a few weeks"
+    assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
+    assert generations[0].generated_text.generated_tokens == 7

-    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is None

-    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "a few weeks"
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == "a few weeks"
    assert (
-        generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[0]
+        generations[0].request_id == default_multi_requests_seq2seq_lm_batch.requests[0].id
    )
-    assert generated_texts[0].generated_tokens == 7
+    assert generations[0].generated_text.generated_tokens == 7
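The new assertions above exercise the streaming-oriented fields of Generation: prefill_tokens is only populated on the first step, while token_id and token_text arrive on every step and generated_text only appears at the end. A rough sketch, not part of the test suite, of how a caller could rebuild a request's streamed text from those same fields:

    # Sketch only: reassemble one request's stream from successive Generation lists,
    # using just the attributes the tests above rely on (token_text, generated_text).
    def collect_stream(generations_per_step, request_id):
        pieces = []
        for generations in generations_per_step:
            for generation in generations:
                if generation.request_id == request_id:
                    pieces.append(generation.token_text)
                    if generation.generated_text is not None:
                        return "".join(pieces), generation.generated_text.text
        return "".join(pieces), None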
server/text_generation/models/causal_lm.py  View file @ 017a2a8c
...
@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
from typing import Optional, Tuple, List, Type

from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch
+from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
...
@@ -23,7 +23,6 @@ class CausalLMBatch(Batch):
    # All tokens
    all_input_ids: List[torch.Tensor]
-    all_logprobs: List[Optional[torch.Tensor]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
...
@@ -57,7 +56,6 @@ class CausalLMBatch(Batch):
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []
-        all_logprobs = []

        # Parse batch
        for r in pb.requests:
...
@@ -67,7 +65,6 @@ class CausalLMBatch(Batch):
            stopping_criterias.append(
                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            )
-            all_logprobs.append(None)

        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
...
@@ -89,7 +86,6 @@ class CausalLMBatch(Batch):
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=all_input_ids,
-            all_logprobs=all_logprobs,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
...
@@ -107,7 +103,6 @@ class CausalLMBatch(Batch):
        requests = []
        input_lengths = []
        all_input_ids = []
-        all_logprobs = []
        next_token_choosers = []
        stopping_criterias = []
...
@@ -124,7 +119,6 @@ class CausalLMBatch(Batch):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
-            all_logprobs.extend(batch.all_logprobs)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
...
@@ -225,7 +219,6 @@ class CausalLMBatch(Batch):
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
-            all_logprobs=all_logprobs,
            input_lengths=input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
...
@@ -234,6 +227,9 @@ class CausalLMBatch(Batch):
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )

+    def __len__(self):
+        return len(self.requests)

class CausalLM(Model):
    def __init__(self, model_name: str, quantize=False):
...
@@ -289,7 +285,7 @@ class CausalLM(Model):
    def generate_token(
        self, batch: CausalLMBatch
-    ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
...
@@ -309,14 +305,13 @@ class CausalLM(Model):
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []
-        next_batch_all_logprobs = []

        # Metadata
        next_batch_size = 0
        next_batch_max_sequence_length = 0

-        # Finished requests
-        generated_texts: List[GeneratedText] = []
+        # Results
+        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
...
@@ -326,7 +321,6 @@ class CausalLM(Model):
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
-            batch.all_logprobs,
        )

        # For each member of the batch
...
@@ -337,44 +331,36 @@ class CausalLM(Model):
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
-            all_logprobs,
        ) in enumerate(iterator):
            # Select next token
            tokens, logprobs = next_token_chooser(all_input_ids, logits)
-            next_token = tokens[-1].view(1, 1)
+            next_token_id = tokens[-1].view(1, 1)

            # Append next token to all tokens
-            all_input_ids = torch.cat([all_input_ids, next_token])
+            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

-            if all_logprobs is None:
-                # logprobs of all prompt tokens (except the first one) and the generated token
-                all_logprobs = logprobs.gather(1, all_input_ids[1:])
-            else:
-                # logprob of the generated token
-                next_token_logprob = logprobs[-1, next_token]
-                all_logprobs = torch.cat([all_logprobs, next_token_logprob])
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
-                next_token.squeeze(),
-                self.tokenizer.decode(
-                    next_token.squeeze(), clean_up_tokenization_spaces=False
-                ),
+                next_token_id_squeezed,
+                next_token_text,
            )

            if stop:
                # Decode generated tokens
                generated_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                output_text = request.inputs + generated_text
-                # Slice with input_length to remove padding
-                token_ids = all_input_ids[-new_input_length:]
-                tokens = self.tokenizer.batch_decode(token_ids)
-                # Add NaN for the first prompt token
-                logprobs = [float("nan")] + all_logprobs[-input_length:].squeeze(1).tolist()

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
...
@@ -382,39 +368,58 @@ class CausalLM(Model):
                else:
                    seed = None

-                # Add to the list of finished generations with the original request
-                generated_texts.append(
-                    GeneratedText(
-                        request=request,
-                        output_text=output_text,
-                        generated_tokens=stopping_criteria.current_tokens,
-                        tokens=tokens,
-                        token_ids=token_ids.squeeze(1).tolist(),
-                        logprobs=logprobs,
-                        reason=reason,
-                        seed=seed,
-                    )
-                )
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason, seed
+                )
-            # add to the next batch
            else:
+                # Keep request in the batch
+                generated_text = None
                next_batch_keep_indices.append(i)
-                next_batch_input_ids.append(next_token)
+                next_batch_input_ids.append(next_token_id)
                next_batch_all_input_ids.append(all_input_ids)
-                next_batch_all_logprobs.append(all_logprobs)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

+            # Prefill
+            if stopping_criteria.current_tokens == 1:
+                # Remove generated token to only have prefill and add nan for first prompt token
+                prefill_logprobs = [float("nan")] + logprobs.gather(
+                    1, all_input_ids[1:]
+                ).squeeze(1)[-new_input_length:-1].tolist()
+                prefill_token_ids = all_input_ids[-new_input_length:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts
+                )
+            else:
+                prefill_tokens = None

+            generation = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
+
+            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
-            return generated_texts, None
+            return generations, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
-        if generated_texts:
+        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
...
@@ -461,7 +466,6 @@ class CausalLM(Model):
            position_ids=next_batch_position_ids,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
-            all_logprobs=next_batch_all_logprobs,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
...
@@ -469,4 +473,4 @@ class CausalLM(Model):
            max_sequence_length=next_batch_max_sequence_length,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
-        return generated_texts, next_batch
+        return generations, next_batch
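The diff above constructs Generation and PrefillTokens positionally, so the attribute names are only partly visible here; the real definitions live in text_generation/models/types.py, which is not part of this page. A simplified, assumption-laden sketch of the payload shape each decoding step now emits (names other than text, generated_tokens, request_id, prefill_tokens, token_id, token_text, and generated_text are guesses):

    # Sketch only: stand-ins mirroring the positional arguments used above,
    # not the library's actual class definitions.
    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class SketchPrefillTokens:
        token_ids: List[int]     # prompt token ids (assumed name)
        logprobs: List[float]    # NaN for the first prompt token, which has no logprob
        texts: List[str]         # decoded prompt tokens (assumed name)


    @dataclass
    class SketchGeneratedText:
        text: str                # full output text for the finished request
        generated_tokens: int    # stopping_criteria.current_tokens
        finish_reason: str       # "reason" positional argument (assumed name)
        seed: Optional[int]


    @dataclass
    class SketchGeneration:
        request_id: int
        prefill_tokens: Optional[SketchPrefillTokens]  # only set when current_tokens == 1
        token_id: int
        token_logprob: float
        token_text: str
        generated_text: Optional[SketchGeneratedText]  # only set on the final step

On the first step a Generation carries the prefill tokens alongside the first generated token; on later steps prefill_tokens is None, and generated_text is populated only once the stopping criteria fire.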
server/text_generation/models/seq2seq_lm.py  View file @ 017a2a8c
...
@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokeniz
from typing import Optional, Tuple, List, Type

from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch
+from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
...
@@ -30,7 +30,6 @@ class Seq2SeqLMBatch(Batch):
    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
-    decoder_logprobs: List[Optional[torch.Tensor]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
...
@@ -64,7 +63,6 @@ class Seq2SeqLMBatch(Batch):
        decoder_input_ids = []
        decoder_input_lengths = []
-        decoder_logprobs = []

        # Parse batch
        for r in pb.requests:
...
@@ -77,7 +75,6 @@ class Seq2SeqLMBatch(Batch):
            stopping_criterias.append(
                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            )
-            decoder_logprobs.append(None)

        # Tokenize batch
        pad_to_multiple_of = 8 if device.type == "cuda" else None
...
@@ -102,7 +99,6 @@ class Seq2SeqLMBatch(Batch):
            past_key_values=None,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
-            decoder_logprobs=decoder_logprobs,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),
...
@@ -125,7 +121,6 @@ class Seq2SeqLMBatch(Batch):
        requests = []
        input_lengths = []
        decoder_input_lengths = []
-        decoder_logprobs = []
        next_token_choosers = []
        stopping_criterias = []
...
@@ -146,7 +141,6 @@ class Seq2SeqLMBatch(Batch):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
-            decoder_logprobs.extend(batch.decoder_logprobs)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
...
@@ -283,7 +277,6 @@ class Seq2SeqLMBatch(Batch):
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
-            decoder_logprobs=decoder_logprobs,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
...
@@ -291,6 +284,9 @@ class Seq2SeqLMBatch(Batch):
            max_decoder_input_length=max_decoder_input_length,
        )

+    def __len__(self):
+        return len(self.requests)

class Seq2SeqLM(Model):
    def __init__(self, model_name: str, quantize=False):
...
@@ -364,7 +360,7 @@ class Seq2SeqLM(Model):
    def generate_token(
        self, batch: Seq2SeqLMBatch
-    ) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
...
@@ -386,7 +382,6 @@ class Seq2SeqLM(Model):
        next_batch_input_lengths = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []
-        next_batch_decoder_logprobs = []

        # Metadata
        next_batch_size = 0
...
@@ -394,14 +389,13 @@ class Seq2SeqLM(Model):
        next_batch_max_decoder_input_length = 0

        # Finished requests
-        generated_texts: List[GeneratedText] = []
+        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.decoder_input_lengths,
-            batch.decoder_logprobs,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
...
@@ -414,7 +408,6 @@ class Seq2SeqLM(Model):
            request,
            input_length,
            decoder_input_length,
-            decoder_logprobs,
            logits,
            next_token_chooser,
            stopping_criteria,
...
@@ -422,35 +415,28 @@ class Seq2SeqLM(Model):
            decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
-            next_token, logprobs = next_token_chooser(decoder_input_ids, logits)
+            next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)

            # Append next token to decoder tokens
-            decoder_input_ids = torch.cat([decoder_input_ids, next_token])
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
            new_decoder_input_length = decoder_input_length + 1

-            next_token_logprob = logprobs[-1, next_token]
-            if decoder_logprobs is None:
-                decoder_logprobs = next_token_logprob
-            else:
-                decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )

            # Evaluate stopping criteria
-            stop, reason = stopping_criteria(
-                next_token.squeeze(),
-                self.tokenizer.decode(
-                    next_token.squeeze(), clean_up_tokenization_spaces=False
-                ),
-            )
+            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
-                token_ids = decoder_input_ids[-new_decoder_input_length:]
-                output_text = self.decode(token_ids)
-                tokens = self.tokenizer.batch_decode(token_ids)
-                # Add NaN for the bos token
-                logprobs = [float("nan")] + decoder_logprobs[-decoder_input_length:].tolist()
+                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
...
@@ -458,27 +444,17 @@ class Seq2SeqLM(Model):
                else:
                    seed = None

-                # Add to the list of finished generations with the original request
-                generated_texts.append(
-                    GeneratedText(
-                        request=request,
-                        output_text=output_text,
-                        generated_tokens=stopping_criteria.current_tokens,
-                        tokens=tokens,
-                        token_ids=token_ids.tolist(),
-                        logprobs=logprobs,
-                        reason=reason,
-                        seed=seed,
-                    )
-                )
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason, seed
+                )
-            # add to the next batch
            else:
+                # Keep request in the batch
+                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
-                next_batch_decoder_logprobs.append(decoder_logprobs)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
...
@@ -486,14 +462,39 @@ class Seq2SeqLM(Model):
                next_batch_max_decoder_input_length, new_decoder_input_length
            )

+            # Prefill
+            if stopping_criteria.current_tokens == 1:
+                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, [float("nan")], prefill_texts
+                )
+            else:
+                prefill_tokens = None

+            generation = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
+
+            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
-            return generated_texts, None
+            return generations, None

        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
-        if generated_texts:
+        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
...
@@ -551,11 +552,10 @@ class Seq2SeqLM(Model):
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
-            decoder_logprobs=next_batch_decoder_logprobs,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
        )

-        return generated_texts, next_batch
+        return generations, next_batch
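Both model backends now return a Generation per request on every step, which is what lets the router stream tokens over Server-Side Events as this commit's title describes. A rough illustration of that consumption pattern; the event payload and field names below are illustrative, not the router's actual wire format:

    # Sketch only: forward each step's tokens as Server-Sent Events as soon as
    # they are produced, using only attributes visible in the diffs above.
    import json


    def sse_events(model, batch):
        while batch is not None:
            generations, batch = model.generate_token(batch)
            for generation in generations:
                payload = {
                    "token_text": generation.token_text,
                    "finished": generation.generated_text is not None,
                }
                yield f"data: {json.dumps(payload)}\n\n"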