OpenDAS / text-generation-inference, commit 218c9ada (unverified)
Authored May 24, 2023 by OlivierDehaene; committed by GitHub on May 24, 2023

feat: decrease IPC proto size (#367)

Closes #307 #308
parent d31562f3

Showing 14 changed files with 108 additions and 88 deletions.
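The change itself is simple: once a batch has gone through prefill, the Python shards already hold the full Request objects in their own cache, so the router only needs to refer to requests by id. The new CachedBatch message and the id-based filter calls below carry a handful of uint64 values instead of repeated Request messages. A minimal sketch of that payload difference, using hypothetical dataclasses as stand-ins for the generated protobuf types:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Request:            # hypothetical stand-in for the proto Request
    id: int
    inputs: str           # the full prompt text travels with every Request
    truncate: int = 1024

@dataclass
class CachedBatch:        # mirrors the new proto message: ids only
    id: int
    request_ids: List[int] = field(default_factory=list)
    size: int = 0
    max_tokens: int = 0

# Before this commit, filtering or decoding a cached batch re-sent whole
# Request messages (prompts included); now the same call ships a few ids.
batch = CachedBatch(id=0, request_ids=[0, 1, 2], size=3, max_tokens=3072)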
benchmark/src/generation.rs                               +3  -3
proto/generate.proto                                       +17 -6
router/client/src/client.rs                                +9  -6
router/client/src/lib.rs                                   +2  -2
router/client/src/sharded_client.rs                        +14 -11
router/src/infer.rs                                        +9  -9
server/tests/models/test_bloom.py                          +3  -3
server/tests/models/test_causal_lm.py                      +3  -3
server/tests/models/test_seq2seq_lm.py                     +3  -3
server/text_generation_server/models/causal_lm.py          +12 -10
server/text_generation_server/models/flash_causal_lm.py    +17 -14
server/text_generation_server/models/seq2seq_lm.py         +13 -13
server/text_generation_server/models/types.py              +2  -2
server/text_generation_server/server.py                    +1  -3
benchmark/src/generation.rs

 use std::time::{Duration, Instant};
 use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
+    Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
     StoppingCriteriaParameters,
 };
 use tokenizers::{Tokenizer, TruncationDirection};

@@ -126,7 +126,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
-) -> Result<(Prefill, Batch), ClientError> {
+) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
     let requests = (0..batch_size)
         .map(|id| Request {

@@ -180,7 +180,7 @@ async fn prefill(
 }

 /// Run a full decode
-async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
+async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
     let mut decode_length = 0;
     let batch_size = batch.size;
proto/generate.proto

@@ -100,6 +100,17 @@ message Batch {
     uint32 max_tokens = 4;
 }

+message CachedBatch {
+    /// Batch ID
+    uint64 id = 1;
+    /// Individual requests ids
+    repeated uint64 request_ids = 2;
+    /// Batch size (==len(requests))
+    uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
+}
+
 enum FinishReason {
     FINISH_REASON_LENGTH = 0;
     FINISH_REASON_EOS_TOKEN = 1;

@@ -140,19 +151,19 @@ message Generation {
     /// Is it a special token
     bool token_is_special = 6;
     /// Complete generated text
-    GeneratedText generated_text = 7;
+    optional GeneratedText generated_text = 7;
 }

 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated Request keep_requests = 2;
+    repeated uint64 request_ids = 2;
 }

 message FilterBatchResponse {
     /// Filtered Batch (cached)
-    Batch batch = 1;
+    CachedBatch batch = 1;
 }

@@ -165,17 +176,17 @@ message PrefillResponse {
     /// Generation
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    optional CachedBatch batch = 2;
 }

 message DecodeRequest {
     /// Cached batches
-    repeated Batch batches = 1;
+    repeated CachedBatch batches = 1;
 }

 message DecodeResponse {
     /// Decodes
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    optional CachedBatch batch = 2;
 }
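A hedged sketch of what the slimmer messages look like from Python, assuming generate_pb2 is regenerated from the proto above (field names come from the diff; the gRPC stub setup is omitted and the import path is the one the server code appears to use):

from text_generation_server.pb import generate_pb2

# Keep only requests 0 and 2 of cached batch 4: the request now carries ids,
# not full Request messages, which is what shrinks the IPC payload.
filter_req = generate_pb2.FilterBatchRequest(batch_id=4, request_ids=[0, 2])

# Responses wrap a CachedBatch rather than a full Batch.
cached = generate_pb2.CachedBatch(id=4, request_ids=[0, 2], size=2, max_tokens=2048)
print(list(cached.request_ids))  # [0, 2]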
router/client/src/client.rs

@@ -83,11 +83,11 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        keep_requests: Vec<Request>,
-    ) -> Result<Option<Batch>> {
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            keep_requests,
+            request_ids,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();

@@ -99,7 +99,10 @@ impl Client {
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
-    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((response.generations, response.batch))

@@ -112,8 +115,8 @@ impl Client {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((response.generations, response.batch))
router/client/src/lib.rs

@@ -9,8 +9,8 @@ pub use client::Client;
 pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
-    Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
-    Request, StoppingCriteriaParameters,
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
+    PrefillTokens, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
router/client/src/sharded_client.rs

 /// Multi shard Client
-use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
+use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;

@@ -76,12 +76,12 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        keep_requests: Vec<Request>,
-    ) -> Result<Option<Batch>> {
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
+            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()

@@ -92,13 +92,16 @@ impl ShardedClient {
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
-    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
             join_all(futures).await.into_iter().collect();
         merge_generations(results?)
     }

@@ -110,14 +113,14 @@ impl ShardedClient {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
             join_all(futures).await.into_iter().collect();
         merge_generations(results?)
     }

@@ -125,8 +128,8 @@ impl ShardedClient {
 /// Merge generations from the different model shards
 fn merge_generations(
-    mut results: Vec<(Vec<Generation>, Option<Batch>)>,
-) -> Result<(Vec<Generation>, Option<Batch>)> {
+    mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
+) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
     let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
     for (mut shard_generations, _) in results.into_iter() {
router/src/infer.rs

@@ -12,7 +12,7 @@ use std::sync::{
     Arc,
 };
 use text_generation_client::{
-    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+    Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};

@@ -352,7 +352,7 @@ async fn prefill(
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
     generation_health: &Arc<AtomicBool>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");

@@ -386,10 +386,10 @@ async fn prefill(
 #[instrument(skip_all)]
 async fn decode(
     client: &mut ShardedClient,
-    batches: Vec<Batch>,
+    batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
     generation_health: &Arc<AtomicBool>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");

@@ -425,9 +425,9 @@ async fn decode(
 #[instrument(skip_all)]
 async fn filter_batch(
     client: &mut ShardedClient,
-    next_batch: Option<Batch>,
+    next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let mut batch = next_batch?;

     // No need to filter

@@ -438,9 +438,9 @@ async fn filter_batch(
     let id = batch.id;

     // Retain only requests that are still in entries
-    batch.requests.retain(|r| entries.contains_key(&r.id));
+    batch.request_ids.retain(|id| entries.contains_key(id));

-    if batch.requests.is_empty() {
+    if batch.request_ids.is_empty() {
         // All requests have been filtered out
         // Next batch is now empty
         // Clear it from the Python shards cache

@@ -450,7 +450,7 @@ async fn filter_batch(
     } else {
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, batch.requests).await.unwrap()
+        client.filter_batch(id, batch.request_ids).await.unwrap()
     }
 }
server/tests/models/test_bloom.py

@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()

-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])

     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1

@@ -286,7 +286,7 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])

     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens

@@ -309,7 +309,7 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )

-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])

     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
server/tests/models/test_causal_lm.py

@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )

-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])

     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1

@@ -285,7 +285,7 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])

     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens

@@ -306,7 +306,7 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )

-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])

     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
server/tests/models/test_seq2seq_lm.py

@@ -190,7 +190,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5

-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])

     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)

@@ -323,7 +323,7 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5

-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])

     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

@@ -333,7 +333,7 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7

-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])

     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
server/text_generation_server/models/causal_lm.py

@@ -53,10 +53,10 @@ class CausalLMBatch(Batch):
     # Past metadata
     keys_head_dim_last: bool = True

-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )

@@ -143,16 +143,17 @@ class CausalLMBatch(Batch):
     )

     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self

         keep_indices = []

         # New values after filtering
         requests_idx_mapping = {}
+        requests = []
         input_lengths = []
         prefix_offsets = []
         read_offsets = []

@@ -165,11 +166,12 @@ class CausalLMBatch(Batch):
         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0

-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
             keep_indices.append(idx)

+            requests.append(self.requests[idx])
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
             all_input_ids.append(self.all_input_ids[idx])

@@ -220,7 +222,7 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
             del past_values

-        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
+        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
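The rewritten filter in CausalLMBatch (and in the flash and seq2seq variants below) follows one pattern: look each incoming id up in requests_idx_mapping, then re-gather every per-request list by that index. A self-contained sketch of the pattern with simplified state (only requests and input_lengths instead of the real batch fields):

from typing import Dict, List

def filter_by_ids(
    requests: List[dict],                  # each item carries an "id" key
    requests_idx_mapping: Dict[int, int],  # request id -> position in the batch
    input_lengths: List[int],
    request_ids: List[int],
):
    """Keep only the listed request ids, re-gathering per-request state."""
    new_requests, new_lengths, new_mapping = [], [], {}
    for i, request_id in enumerate(request_ids):
        idx = requests_idx_mapping[request_id]  # old position of this request
        new_mapping[request_id] = i             # new position after filtering
        new_requests.append(requests[idx])
        new_lengths.append(input_lengths[idx])
    return new_requests, new_lengths, new_mapping

# Drop request 1, keep 0 and 2:
reqs = [{"id": 0}, {"id": 1}, {"id": 2}]
print(filter_by_ids(reqs, {0: 0, 1: 1, 2: 2}, [5, 9, 7], [0, 2]))
# -> ([{'id': 0}, {'id': 2}], [5, 7], {0: 0, 2: 1})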
server/text_generation_server/models/flash_causal_lm.py

@@ -62,10 +62,10 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int

-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )

@@ -161,14 +161,14 @@ class FlashCausalLMBatch(Batch):
     )

     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self

-        single_request = len(requests) == 1
+        single_request = len(request_ids) == 1

         # Cumulative length
         cumulative_length = 0

@@ -176,16 +176,17 @@ class FlashCausalLMBatch(Batch):
         # New values after filtering
         requests_idx_mapping = {}

-        input_ids = self.input_ids.new_empty(len(requests))
-        position_ids = self.position_ids.new_empty(len(requests))
+        input_ids = self.input_ids.new_empty(len(request_ids))
+        position_ids = self.position_ids.new_empty(len(request_ids))
         # Create on CPU to only move to GPU once instead of at every copy
-        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
+        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
         cu_seqlens_q = torch.arange(
-            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
+            0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
         )
         max_seqlen = 0
         past_key_values = []
+        requests = []
         all_input_ids = []
         all_input_ids_tensor = []

@@ -198,9 +199,11 @@ class FlashCausalLMBatch(Batch):
         max_tokens = 0

-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
+
+            requests.append(self.requests[idx])

             # Get length
             request_input_length = self.input_lengths[idx]
server/text_generation_server/models/seq2seq_lm.py

@@ -57,11 +57,11 @@ class Seq2SeqLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int

-    def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )

@@ -152,18 +152,17 @@ class Seq2SeqLMBatch(Batch):
     )

     @tracer.start_as_current_span("filter")
-    def filter(
-        self, requests: List[generate_pb2.Request]
-    ) -> Optional["Seq2SeqLMBatch"]:
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self

         keep_indices = []

         # New values after filtering
         requests_idx_mapping = {}
+        requests = []
         input_lengths = []
         decoder_input_lengths = []
         prefix_offsets = []

@@ -180,11 +179,12 @@ class Seq2SeqLMBatch(Batch):
         total_remaining_decode_tokens = 0

-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
             keep_indices.append(idx)

+            requests.append(self.requests[idx])
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])

@@ -239,7 +239,7 @@ class Seq2SeqLMBatch(Batch):
             layer[3] = layer[3][keep_indices, :, -max_input_length:]

         max_tokens = (
-            len(requests) * (max_input_length + max_decoder_input_length)
+            len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
         )
server/text_generation_server/models/types.py

@@ -12,7 +12,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason
 class Batch(ABC):
     @abstractmethod
-    def to_pb(self) -> generate_pb2.Batch:
+    def to_pb(self) -> generate_pb2.CachedBatch:
         raise NotImplementedError

     @classmethod

@@ -26,7 +26,7 @@ class Batch(ABC):
         raise NotImplementedError

     @abstractmethod
-    def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
+    def filter(self, request_ids: List[int]) -> "Batch":
         raise NotImplementedError

     @classmethod
server/text_generation_server/server.py

@@ -42,15 +42,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             self.cache.delete(request.id)
         else:
             self.cache.clear()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
         return generate_pb2.ClearCacheResponse()

     async def FilterBatch(self, request, context):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.keep_requests)
+        filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)

         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
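Taken together, the FilterBatch RPC now pops the cached batch, filters it by request.request_ids, puts it back, and returns the slim CachedBatch protobuf. A rough sketch of that flow with a hypothetical in-memory cache, not the actual service class:

class Cache:
    """Toy stand-in for the server's batch cache."""
    def __init__(self):
        self._batches = {}
    def pop(self, batch_id):
        return self._batches.pop(batch_id, None)
    def set(self, batch):
        self._batches[batch.batch_id] = batch

def filter_batch(cache, batch_id, request_ids):
    batch = cache.pop(batch_id)
    if batch is None:
        raise ValueError(f"Batch ID {batch_id} not found in cache.")
    filtered = batch.filter(request_ids)  # ids only, no Request payloads
    cache.set(filtered)
    return filtered.to_pb()               # serialize as a CachedBatch message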