OpenDAS / text-generation-inference

Commit ebc74d56 (unverified)
Authored Apr 24, 2023 by OlivierDehaene; committed by GitHub on Apr 24, 2023
Parent commit: 98a3e0d1

feat(router): use number of tokens in batch as input for dynamic batching (#226)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Showing 16 changed files with 400 additions and 173 deletions (+400, -173).
launcher/src/main.rs (+19, -4)
proto/generate.proto (+17, -0)
router/client/src/client.rs (+16, -0)
router/client/src/sharded_client.rs (+17, -1)
router/src/infer.rs (+89, -81)
router/src/main.rs (+16, -3)
router/src/queue.rs (+77, -40)
router/src/server.rs (+6, -3)
router/src/validation.rs (+8, -7)
server/tests/models/test_bloom.py (+1, -3)
server/tests/models/test_causal_lm.py (+4, -4)
server/text_generation_server/models/causal_lm.py (+34, -7)
server/text_generation_server/models/flash_causal_lm.py (+26, -2)
server/text_generation_server/models/galactica.py (+5, -0)
server/text_generation_server/models/seq2seq_lm.py (+55, -15)
server/text_generation_server/server.py (+10, -3)
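At a high level, this change switches dynamic batching from a request-count cap (`max_batch_size`) to a token budget (`max_batch_total_tokens`): each request is charged its prompt tokens plus its worst-case `max_new_tokens`, and the queue only admits requests while that running total fits the budget. A minimal sketch of that admission rule, with a hypothetical `QueuedRequest` type standing in for the router's queue entries; the real queue additionally charges padded models the batch-wide maximum input length per request, as shown in router/src/queue.rs below.

struct QueuedRequest {
    input_length: u32,   // prompt tokens
    max_new_tokens: u32, // worst-case decode tokens
}

/// Pop requests from the front of `queue` while the running total of
/// prefill + decode tokens stays within `token_budget` (non-padded accounting).
fn admit_within_budget(queue: &mut Vec<QueuedRequest>, token_budget: u32) -> Vec<QueuedRequest> {
    let mut admitted = Vec::new();
    let mut prefill_tokens: u32 = 0;
    let mut decode_tokens: u32 = 0;

    while let Some(req) = queue.first() {
        let next_prefill = prefill_tokens + req.input_length;
        let next_decode = decode_tokens + req.max_new_tokens;
        if next_prefill + next_decode > token_budget {
            // Over budget: leave this request at the front for a later batch.
            break;
        }
        prefill_tokens = next_prefill;
        decode_tokens = next_decode;
        admitted.push(queue.remove(0));
    }
    admitted
}

fn main() {
    let mut queue = vec![
        QueuedRequest { input_length: 10, max_new_tokens: 20 },
        QueuedRequest { input_length: 100, max_new_tokens: 20 },
        QueuedRequest { input_length: 500, max_new_tokens: 500 },
    ];
    let admitted = admit_within_budget(&mut queue, 200);
    assert_eq!(admitted.len(), 2); // the third request would blow the 200-token budget
    assert_eq!(queue.len(), 1);
}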
launcher/src/main.rs

@@ -39,8 +39,12 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "1512", long, env)]
     max_total_tokens: usize,
-    #[clap(default_value = "32", long, env)]
-    max_batch_size: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "32000", long, env)]
+    max_batch_total_tokens: u32,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]

@@ -93,6 +97,8 @@ fn main() -> ExitCode {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        max_batch_total_tokens,
+        waiting_served_ratio,
         max_waiting_tokens,
         port,
         shard_uds_path,

@@ -380,8 +386,8 @@ fn main() -> ExitCode {
         max_input_length.to_string(),
        "--max-total-tokens".to_string(),
         max_total_tokens.to_string(),
-        "--max-batch-size".to_string(),
-        max_batch_size.to_string(),
+        "--waiting-served-ratio".to_string(),
+        waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
         max_waiting_tokens.to_string(),
         "--port".to_string(),

@@ -392,6 +398,15 @@ fn main() -> ExitCode {
         model_id,
     ];

+    // Deprecate max_batch_size
+    if let Some(max_batch_size) = max_batch_size {
+        argv.push("--max-batch-size".to_string());
+        argv.push(max_batch_size.to_string())
+    } else {
+        argv.push("--max-batch-total-tokens".to_string());
+        argv.push(max_batch_total_tokens.to_string())
+    }
+
     // Model optional revision
     if let Some(ref revision) = revision {
         argv.push("--revision".to_string());
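For backward compatibility the launcher still accepts `--max-batch-size` and forwards it to the router, which converts it into a token budget (see router/src/main.rs below). A rough worked example of that conversion using the defaults visible in this diff; the numbers are only illustrative:

fn main() {
    // Deprecated path: max_batch_total_tokens = max_batch_size * max_total_tokens.
    let max_batch_size: usize = 32; // old default removed above
    let max_total_tokens: usize = 1512; // default shown above
    let max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
    assert_eq!(max_batch_total_tokens, 48_384);
}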
proto/generate.proto

@@ -9,6 +9,8 @@ service TextGenerationService {
     rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
     /// Empties batch cache
     rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
+    /// Remove requests from a cached batch
+    rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
     /// Prefill batch and decode first token
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches

@@ -89,6 +91,8 @@ message Batch {
     repeated Request requests = 2;
     /// Batch size (==len(requests))
     uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
 }

 enum FinishReason {

@@ -134,6 +138,19 @@ message Generation {
     GeneratedText generated_text = 7;
 }

+message FilterBatchRequest {
+    /// Batch ID
+    uint64 batch_id = 1;
+    /// Requests to keep
+    repeated Request keep_requests = 2;
+}
+
+message FilterBatchResponse {
+    /// Filtered Batch (cached)
+    Batch batch = 1;
+}
+
 message PrefillRequest {
     /// Batch
     Batch batch = 1;
router/client/src/client.rs

@@ -70,6 +70,22 @@ impl Client {
         Ok(())
     }

+    /// Filter a cached batch
+    #[instrument(skip(self))]
+    pub async fn filter_batch(
+        &mut self,
+        batch_id: u64,
+        keep_requests: Vec<Request>,
+    ) -> Result<Option<Batch>> {
+        let request = tonic::Request::new(FilterBatchRequest {
+            batch_id,
+            keep_requests,
+        })
+        .inject_context();
+        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
+        Ok(filtered_batch.batch)
+    }
+
     /// Generate one token for each request in the given batch
     ///
     /// Returns Generation for each request in batch
router/client/src/sharded_client.rs

 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation, ShardInfo};
+use crate::{Batch, Client, Generation, Request, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;

@@ -59,6 +59,22 @@ impl ShardedClient {
         join_all(futures).await.into_iter().collect()
     }

+    /// Filter a cached batch
+    #[instrument(skip(self))]
+    pub async fn filter_batch(
+        &mut self,
+        batch_id: u64,
+        keep_requests: Vec<Request>,
+    ) -> Result<Option<Batch>> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
+            .collect();
+        // all shards return the same message
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Generate one token for each request in the given batch
     ///
     /// Returns Generation for each request in batch
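The sharded client fans the new `filter_batch` call out to every shard and keeps a single response, since all shards answer identically. A self-contained sketch of that fan-out pattern, using a hypothetical `Shard` type in place of the per-shard gRPC client:

use futures::future::join_all;

struct Shard {
    id: usize,
}

impl Shard {
    // Stand-in for the real gRPC call: pretend the shard prunes its cached batch
    // and echoes the batch id back.
    async fn filter_batch(&self, batch_id: u64, keep_request_ids: Vec<u64>) -> Option<u64> {
        let _ = (self.id, keep_request_ids);
        Some(batch_id)
    }
}

async fn broadcast_filter(shards: &[Shard], batch_id: u64, keep_request_ids: Vec<u64>) -> Option<u64> {
    let futures: Vec<_> = shards
        .iter()
        .map(|shard| shard.filter_batch(batch_id, keep_request_ids.clone()))
        .collect();
    // All shards return the same message, so keeping the last one is enough.
    join_all(futures).await.pop().flatten()
}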
router/src/infer.rs

@@ -39,12 +39,14 @@ impl Infer {
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
-        max_batch_size: usize,
+        waiting_served_ratio: f32,
+        max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
+        requires_padding: bool,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new();
+        let queue = Queue::new(requires_padding);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });

@@ -52,7 +54,8 @@ impl Infer {
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             client,
-            max_batch_size,
+            waiting_served_ratio,
+            max_batch_total_tokens,
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),

@@ -232,18 +235,12 @@ impl Infer {
 /// Batches requests and sends them to the inference server
 async fn batching_task(
     mut client: ShardedClient,
-    max_batch_size: usize,
+    waiting_served_ratio: f32,
+    max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
 ) {
-    // Minimum batch size after which we try to add more requests
-    let limit_min_batch_size = if max_batch_size > 1 {
-        (max_batch_size / 2) as u32
-    } else {
-        0
-    };
-
     // Infinite loop
     loop {
         // Wait for a notification from the Infer struct

@@ -252,7 +249,9 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
-        while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
+        while let Some((mut entries, batch, span)) =
+            queue.next_batch(None, max_batch_total_tokens).await
+        {
             let mut cached_batch = prefill(&mut client, batch, &mut entries)
                 .instrument(span)
                 .await;

@@ -263,48 +262,57 @@ async fn batching_task(
             while let Some(batch) = cached_batch {
                 // Get current batch info
                 let batch_size = batch.size;
+                let batch_max_tokens = batch.max_tokens;
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size", batch_size as f64);
+                metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
+
+                let min_size = if waiting_tokens >= max_waiting_tokens {
+                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                    // to add a new batch even though its size might be small
+                    None
+                } else {
+                    // Minimum batch size
+                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
+                };
+
+                let token_budget = max_batch_total_tokens - batch_max_tokens;
+
+                // Try to get a new batch
+                if let Some((mut new_entries, new_batch, span)) =
+                    queue.next_batch(min_size, token_budget).await
+                {
+                    // Tracking metrics
+                    if min_size.is_some() {
+                        metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
+                    } else {
+                        metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
+                    }

-                // If the current batch is too small, we try to add more requests to it
-                if batch_size <= limit_min_batch_size {
-                    let min_size = match waiting_tokens {
-                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                        // to add a new batch even though its size might be small
-                        _ if waiting_tokens >= max_waiting_tokens => None,
-                        // Minimum size criteria
-                        _ => Some(limit_min_batch_size as usize),
-                    };
-
-                    // Try to get a new batch
-                    if let Some((mut new_entries, new_batch, span)) = queue
-                        .next_batch(min_size, max_batch_size - batch_size as usize)
-                        .await
-                    {
-                        entries.iter_mut().for_each(|(_, entry)| {
-                            // Create a new span to add the info that this entry is waiting
-                            // because a new batch is being computed
-                            let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
-                            // Add relationships
-                            span.follows_from(&entry_waiting_span);
-                            entry_waiting_span.follows_from(&span);
-                            // Update entry
-                            entry.temp_span = Some(entry_waiting_span);
-                        });
-
-                        // Generate one token for this new batch to have the attention past in cache
-                        let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                            .instrument(span)
-                            .await;
-                        // Reset waiting counter
-                        waiting_tokens = 1;
-                        // Extend current batch with the new batch
-                        if let Some(new_cached_batch) = new_cached_batch {
-                            entries.extend(new_entries);
-                            batches.push(new_cached_batch);
-                        }
-                    }
+                    entries.iter_mut().for_each(|(_, entry)| {
+                        // Create a new span to add the info that this entry is waiting
+                        // because a new batch is being computed
+                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+                        // Add relationships
+                        span.follows_from(&entry_waiting_span);
+                        entry_waiting_span.follows_from(&span);
+                        // Update entry
+                        entry.temp_span = Some(entry_waiting_span);
+                    });
+
+                    // Generate one token for this new batch to have the attention past in cache
+                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        .instrument(span)
+                        .await;
+                    // Reset waiting counter
+                    waiting_tokens = 1;
+                    // Extend current batch with the new batch
+                    if let Some(new_cached_batch) = new_cached_batch {
+                        entries.extend(new_entries);
+                        batches.push(new_cached_batch);
+                    }
                 }

                 // Create span for this batch to add context to inference calls
                 let next_batch_size = entries.len();
                 let next_batch_span =

@@ -325,6 +333,7 @@ async fn batching_task(
                 waiting_tokens += 1;
             }
             metrics::gauge!("tgi_batch_current_size", 0.0);
+            metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
         }
     }
 }

@@ -341,22 +350,11 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
-            let next_batch = match next_batch {
-                None => None,
-                Some(batch) => {
-                    let id = batch.id;
-                    let next_batch = filter_batch(batch, entries);
-                    // Next batch is now empty
-                    // Clear it from the Python shards cache
-                    if next_batch.is_none() {
-                        let _ = client.clear_cache(Some(id)).await;
-                    }
-                    next_batch
-                }
-            };
+            let next_batch = filter_batch(client, next_batch, entries).await;

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");

@@ -384,22 +382,11 @@ async fn decode(
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
-            let next_batch = match next_batch {
-                None => None,
-                Some(batch) => {
-                    let id = batch.id;
-                    let next_batch = filter_batch(batch, entries);
-                    // Next batch is now empty
-                    // Clear it from the Python shards cache
-                    if next_batch.is_none() {
-                        let _ = client.clear_cache(Some(id)).await;
-                    }
-                    next_batch
-                }
-            };
+            let next_batch = filter_batch(client, next_batch, entries).await;

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");

@@ -419,14 +406,35 @@ async fn decode(
 /// Filter a `batch` and remove all requests not present in `entries`
 #[instrument(skip_all)]
-fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
+async fn filter_batch(
+    client: &mut ShardedClient,
+    next_batch: Option<Batch>,
+    entries: &IntMap<u64, Entry>,
+) -> Option<Batch> {
+    let mut batch = next_batch?;
+
+    // No need to filter
+    if batch.size as usize == entries.len() {
+        return Some(batch);
+    }
+
+    let id = batch.id;
+
+    // Retain only requests that are still in entries
     batch.requests.retain(|r| entries.contains_key(&r.id));
-    let size = batch.requests.len();
-    if size == 0 {
-        return None;
-    }
-    batch.size = size as u32;
-    Some(batch)
+
+    if batch.requests.is_empty() {
+        // All requests have been filtered out
+        // Next batch is now empty
+        // Clear it from the Python shards cache
+        // We unwrap here as we need to panic since we cannot recover if this method fails
+        client.clear_cache(Some(id)).await.unwrap();
+        None
+    } else {
+        // Filter Python shard cache
+        // We unwrap here as we need to panic since we cannot recover if this method fails
+        client.filter_batch(id, batch.requests).await.unwrap()
+    }
 }

 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
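Condensed, the decision the batching loop above makes on every iteration comes down to two values handed to the queue: a minimum acceptable size for the waiting batch (so decoding is only interrupted when it is worth it) and the token budget that is still free. A simplified, self-contained sketch of that computation with assumed numbers, not the router's own types:

/// Simplified version of the per-iteration concatenation decision in `batching_task`:
/// returns the `(min_size, token_budget)` pair passed to `queue.next_batch`.
fn concat_params(
    batch_size: u32,
    batch_max_tokens: u32,
    waiting_tokens: usize,
    max_waiting_tokens: usize,
    waiting_served_ratio: f32,
    max_batch_total_tokens: u32,
) -> (Option<usize>, u32) {
    let min_size = if waiting_tokens >= max_waiting_tokens {
        // Waited long enough: accept any new batch, however small.
        None
    } else {
        // Otherwise only interrupt decoding for a reasonably sized new batch.
        Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
    };
    // Whatever is left of the global budget after the running batch's worst case.
    let token_budget = max_batch_total_tokens - batch_max_tokens;
    (min_size, token_budget)
}

fn main() {
    // Hypothetical numbers: a running batch of 8 requests worth up to 9_000 tokens.
    let (min_size, budget) = concat_params(8, 9_000, 3, 20, 1.2, 32_000);
    assert_eq!(min_size, Some(9)); // 8 * 1.2 = 9.6, floored
    assert_eq!(budget, 23_000);
}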
router/src/main.rs

@@ -31,8 +31,12 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "1512", long, env)]
     max_total_tokens: usize,
-    #[clap(default_value = "32", long, env)]
-    max_batch_size: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "32000", long, env)]
+    max_batch_total_tokens: u32,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]

@@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        waiting_served_ratio,
+        mut max_batch_total_tokens,
         max_waiting_tokens,
         port,
         master_shard_uds_path,

@@ -119,6 +125,12 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);

+            if let Some(max_batch_size) = max_batch_size {
+                tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
+                max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
+                tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
+            }
+
             if tokenizer.is_none() {
                 tracing::warn!(
                     "Could not find a fast tokenizer implementation for {tokenizer_name}"

@@ -174,7 +186,8 @@ fn main() -> Result<(), std::io::Error> {
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
-        max_batch_size,
+        waiting_served_ratio,
+        max_batch_total_tokens,
         max_waiting_tokens,
         sharded_client,
         tokenizer,
router/src/queue.rs

@@ -2,7 +2,6 @@ use crate::infer::InferError;
 use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
 use tokio::sync::oneshot;

@@ -34,12 +33,12 @@ pub(crate) struct Queue {
 }

 impl Queue {
-    pub(crate) fn new() -> Self {
+    pub(crate) fn new(requires_padding: bool) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = flume::unbounded();

         // Launch background queue task
-        tokio::spawn(queue_task(queue_receiver));
+        tokio::spawn(queue_task(requires_padding, queue_receiver));

         Self { queue_sender }
     }

@@ -59,7 +58,7 @@ impl Queue {
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
-        max_size: usize,
+        token_budget: u32,
     ) -> Option<NextBatch> {
         // Create response channel
         let (response_sender, response_receiver) = oneshot::channel();

@@ -68,7 +67,7 @@ impl Queue {
         self.queue_sender
             .send(QueueCommand::NextBatch {
                 min_size,
-                max_size,
+                token_budget,
                 response_sender,
                 span: Span::current(),
             })

@@ -80,20 +79,24 @@ impl Queue {
 }

 // Background task responsible of the queue state
-async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
-    let mut state = State::new();
+async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
+    let mut state = State::new(requires_padding);

     while let Ok(cmd) = receiver.recv_async().await {
         match cmd {
-            QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
+            QueueCommand::Append(entry, span) => {
+                span.in_scope(|| state.append(entry));
+                metrics::increment_gauge!("tgi_queue_size", 1.0);
+            }
             QueueCommand::NextBatch {
                 min_size,
-                max_size,
+                token_budget,
                 response_sender,
                 span,
             } => span.in_scope(|| {
-                let next_batch = state.next_batch(min_size, max_size);
+                let next_batch = state.next_batch(min_size, token_budget);
                 response_sender.send(next_batch).unwrap_or(());
+                metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
             }),
         }
     }

@@ -110,14 +113,18 @@ struct State {
     /// Id of the next batch
     next_batch_id: u64,
+
+    /// Whether the model is using padding
+    requires_padding: bool,
 }

 impl State {
-    fn new() -> Self {
+    fn new(requires_padding: bool) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
+            requires_padding,
         }
     }

@@ -130,11 +137,10 @@ impl State {
         // Push entry in the queue
         self.entries.push_back((self.next_id, entry));
         self.next_id += 1;
-        metrics::increment_gauge!("tgi_queue_size", 1.0);
     }

     // Get the next batch
-    fn next_batch(&mut self, min_size: Option<usize>, max_size: usize) -> Option<NextBatch> {
+    fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
         if self.entries.is_empty() {
             return None;
         }

@@ -146,17 +152,19 @@ impl State {
             }
         }

-        let max_batch_size = min(self.entries.len(), max_size);
-
         // Create span for this batch to add context to inference calls
         let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(&Span::current());

-        let mut batch_requests = Vec::with_capacity(max_batch_size);
+        let mut batch_requests = Vec::with_capacity(self.entries.len());
         let mut batch_entries =
-            IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
+            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

-        // Iterate on buffer
+        let mut max_input_length = 0;
+        let mut prefill_tokens: u32 = 0;
+        let mut decode_tokens: u32 = 0;
+
+        // Pop entries starting from the front of the queue
         while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)

@@ -165,6 +173,24 @@ impl State {
                 continue;
             }

+            if self.requires_padding {
+                // We pad to max input length in the Python shards
+                // We need to take these padding tokens into the equation
+                max_input_length = max_input_length.max(entry.request.input_length);
+                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
+            } else {
+                prefill_tokens += entry.request.input_length;
+            }
+
+            decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+
+            if (prefill_tokens + decode_tokens) > token_budget {
+                // Entry is over budget
+                // Add it back to the front
+                self.entries.push_front((id, entry));
+                break;
+            }
+
             // Create a new span to link the batch back to this entry
             let entry_batch_span = info_span!(parent: &entry.span, "infer");
             // Add relationships

@@ -184,21 +210,29 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
-
-            if batch_requests.len() == max_batch_size {
-                // We have enough requests in the batch
-                break;
-            }
         }

-        metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
-
         // Maybe all entries were dropped because their channel were closed
+        // Empty batch
         if batch_requests.is_empty() {
             return None;
         }

-        // Final batch size once we dropped entries
+        // Check if our batch is big enough
+        if let Some(min_size) = min_size {
+            // Batch is too small
+            if batch_requests.len() < min_size {
+                // Add back entries to the queue in the correct order
+                for r in batch_requests.into_iter().rev() {
+                    let id = r.id;
+                    let entry = batch_entries.remove(&id).unwrap();
+                    self.entries.push_front((id, entry));
+                }
+
+                return None;
+            }
+        }
+
+        // Final batch size
         let size = batch_requests.len() as u32;
         next_batch_span.record("batch_size", size);

@@ -206,11 +240,13 @@ impl State {
             id: self.next_batch_id,
             requests: batch_requests,
             size,
+            max_tokens: (prefill_tokens + decode_tokens),
         };
         // Increment batch id
         self.next_batch_id += 1;

         metrics::histogram!("tgi_batch_next_size", batch.size as f64);
+
         Some((batch_entries, batch, next_batch_span))
     }
 }

@@ -222,7 +258,7 @@ enum QueueCommand {
     Append(Entry, Span),
     NextBatch {
         min_size: Option<usize>,
-        max_size: usize,
+        token_budget: u32,
         response_sender: oneshot::Sender<Option<NextBatch>>,
         span: Span,
     },

@@ -243,6 +279,7 @@ mod tests {
         let entry = Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
+                input_length: 0,
                 truncate: 0,
                 parameters: NextTokenChooserParameters {
                     temperature: 0.0,

@@ -256,7 +293,7 @@ mod tests {
                 },
                 stopping_parameters: StoppingCriteriaParameters {
                     ignore_eos_token: false,
-                    max_new_tokens: 0,
+                    max_new_tokens: 1,
                     stop_sequences: vec![],
                 },
             },

@@ -271,7 +308,7 @@ mod tests {
     #[test]
     fn test_append() {
-        let mut state = State::new();
+        let mut state = State::new(false);
         let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);

@@ -287,7 +324,7 @@ mod tests {
     #[test]
     fn test_next_batch_empty() {
-        let mut state = State::new();
+        let mut state = State::new(false);

         assert!(state.next_batch(None, 1).is_none());
         assert!(state.next_batch(Some(1), 1).is_none());

@@ -295,7 +332,7 @@ mod tests {
     #[test]
     fn test_next_batch_min_size() {
-        let mut state = State::new();
+        let mut state = State::new(false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);

@@ -326,8 +363,8 @@ mod tests {
     }

     #[test]
-    fn test_next_batch_max_size() {
-        let mut state = State::new();
+    fn test_next_batch_token_budget() {
+        let mut state = State::new(false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);

@@ -360,14 +397,14 @@ mod tests {
     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new();
+        let queue = Queue::new(false);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }

     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new();
+        let queue = Queue::new(false);

         assert!(queue.next_batch(None, 1).await.is_none());
         assert!(queue.next_batch(Some(1), 1).await.is_none());

@@ -375,7 +412,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new();
+        let queue = Queue::new(false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);

@@ -397,8 +434,8 @@ mod tests {
     }

     #[tokio::test]
-    async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new();
+    async fn test_queue_next_batch_token_budget() {
+        let queue = Queue::new(false);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);

@@ -423,7 +460,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new();
+        let queue = Queue::new(false);
         let (entry, _) = default_entry();
         queue.append(entry);
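The accounting detail worth calling out in `next_batch` above: for models that pad every sequence to the longest prompt in the batch, each admitted request effectively costs `max_input_length` prefill tokens, so the prefill total is recomputed as `(n + 1) * max_input_length` whenever a request joins; padding-free models just sum the real input lengths. A small stand-alone sketch of the same arithmetic over hypothetical `(input_length, max_new_tokens)` pairs:

/// Estimate the worst-case token footprint of a prospective batch, mirroring the
/// padded vs. non-padded accounting used by the queue (simplified, assumed inputs).
fn batch_max_tokens(requests: &[(u32, u32)], requires_padding: bool) -> u32 {
    let mut max_input_length = 0;
    let mut prefill_tokens = 0;
    let mut decode_tokens = 0;
    for (i, (input_length, max_new_tokens)) in requests.iter().enumerate() {
        if requires_padding {
            // Padded models: every sequence is padded to the longest prompt so far.
            max_input_length = max_input_length.max(*input_length);
            prefill_tokens = (i as u32 + 1) * max_input_length;
        } else {
            prefill_tokens += input_length;
        }
        decode_tokens += max_new_tokens;
    }
    prefill_tokens + decode_tokens
}

fn main() {
    let reqs = [(10, 20), (100, 20)];
    // Padding-free models: 10 + 100 prefill tokens + 40 decode tokens.
    assert_eq!(batch_max_tokens(&reqs, false), 150);
    // Padded models: both prompts are charged 100 tokens -> 200 prefill + 40 decode.
    assert_eq!(batch_max_tokens(&reqs, true), 240);
}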
router/src/server.rs

@@ -511,7 +511,8 @@ pub async fn run(
     max_stop_sequences: usize,
     max_input_length: usize,
     max_total_tokens: usize,
-    max_batch_size: usize,
+    waiting_served_ratio: f32,
+    max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
     client: ShardedClient,
     tokenizer: Option<Tokenizer>,

@@ -571,9 +572,11 @@ pub async fn run(
     let infer = Infer::new(
         client,
         validation,
-        max_batch_size,
+        waiting_served_ratio,
+        max_batch_total_tokens,
         max_waiting_tokens,
         max_concurrent_requests,
+        shard_info.requires_padding,
     );

     // Duration buckets

@@ -604,7 +607,7 @@ pub async fn run(
         .collect();
     // Batch size buckets
     let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
-    let batch_size_buckets: Vec<f64> = (0..max_batch_size).map(|x| (x + 1) as f64).collect();
+    let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();

     // Prometheus handler
     let builder = PrometheusBuilder::new()
router/src/validation.rs

@@ -69,7 +69,7 @@ impl Validation {
         inputs: String,
         truncate: Option<usize>,
         max_new_tokens: u32,
-    ) -> Result<String, ValidationError> {
+    ) -> Result<(String, usize), ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.sender {
             // Create response channel

@@ -105,25 +105,24 @@ impl Validation {
             }

             metrics::histogram!("tgi_request_input_length", input_length as f64);
-            Ok(inputs)
+            Ok((inputs, input_length))
         }
         // Return inputs without validation
         else {
             // In this case, we don't know the real length in tokens of the inputs
             // However, the inputs will be truncated by the python servers
             // We make sure that truncate + max_new_tokens <= self.max_total_tokens
+            let input_length = truncate.unwrap_or(self.max_input_length);

             // Validate MaxNewTokens
-            if (truncate.unwrap_or(self.max_input_length) as u32 + max_new_tokens)
-                > self.max_total_tokens as u32
-            {
+            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                 return Err(ValidationError::MaxNewTokens(
                     self.max_total_tokens - self.max_input_length,
                     max_new_tokens,
                 ));
             }

-            Ok(inputs)
+            Ok((inputs, input_length))
         }
     }

@@ -238,7 +237,7 @@ impl Validation {
             .unwrap_or(Ok(None))?;

         // Validate inputs
-        let inputs = self
+        let (inputs, input_length) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;

@@ -262,6 +261,7 @@ impl Validation {
         Ok(ValidGenerateRequest {
             inputs,
+            input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
             parameters,
             stopping_parameters,

@@ -333,6 +333,7 @@ type TokenizerRequest = (
 #[derive(Debug)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
+    pub input_length: u32,
     pub truncate: u32,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
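One subtle point in the validation change: even without a fast tokenizer, the router now reports an input length by assuming the Python shard will truncate the prompt to `truncate` (or `max_input_length`), and it still enforces `input_length + max_new_tokens <= max_total_tokens`. A stand-alone sketch of that fallback check, with a hypothetical function name and a simplified error type rather than the router's `ValidationError`:

/// Fallback length check used when the real token count is unknown (no fast tokenizer).
/// Returns the assumed input length on success.
fn check_without_tokenizer(
    truncate: Option<usize>,
    max_new_tokens: u32,
    max_input_length: usize,
    max_total_tokens: usize,
) -> Result<usize, String> {
    // Assume the Python shard will truncate the prompt to this many tokens.
    let input_length = truncate.unwrap_or(max_input_length);
    if input_length as u32 + max_new_tokens > max_total_tokens as u32 {
        return Err(format!(
            "`max_new_tokens` must be <= {}",
            max_total_tokens - max_input_length
        ));
    }
    Ok(input_length)
}

fn main() {
    // Hypothetical limits: 1000-token prompts, 1512 total tokens.
    assert_eq!(check_without_tokenizer(None, 512, 1000, 1512), Ok(1000));
    assert!(check_without_tokenizer(Some(1200), 512, 1000, 1512).is_err());
}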
server/tests/models/test_bloom.py

@@ -181,9 +181,7 @@ def test_causal_lm_generate_token_completion_multi(
     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)
server/tests/models/test_causal_lm.py

@@ -174,14 +174,14 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

     # Copy stopping_criterias before filtering
-    stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    stopping_criterias = (
+        default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+    )

     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        stopping_criterias[0].max_new_tokens
-        - stopping_criterias[1].max_new_tokens
-        - 1
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
         generations, next_batch = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)
server/text_generation_server/models/causal_lm.py

@@ -46,6 +46,9 @@ class CausalLMBatch(Batch):
     max_input_length: int
     padding_right_offset: int

+    # Maximum number of tokens this batch will grow to
+    max_tokens: int
+
     # Past metadata
     keys_head_dim_last: bool = True

@@ -54,6 +57,7 @@ class CausalLMBatch(Batch):
             id=self.batch_id,
             requests=self.requests,
             size=len(self),
+            max_tokens=self.max_tokens,
         )

     @classmethod

@@ -73,6 +77,7 @@ class CausalLMBatch(Batch):
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
+        max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)

@@ -84,6 +89,7 @@ class CausalLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
+            max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )

@@ -112,6 +118,8 @@ class CausalLMBatch(Batch):
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

+        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,

@@ -128,6 +136,7 @@ class CausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
         )

     @tracer.start_as_current_span("filter")

@@ -150,6 +159,7 @@ class CausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []

+        total_remaining_decode_tokens = 0
         new_padding_right_offset = 0

         for i, r in enumerate(requests):

@@ -168,19 +178,23 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
-            new_padding_right_offset = max(
-                new_padding_right_offset,
-                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens,
-            )
+            remaining_decode_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
+            total_remaining_decode_tokens += remaining_decode_tokens
+            new_padding_right_offset = max(
+                new_padding_right_offset, remaining_decode_tokens
+            )

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         input_ids = self.input_ids[keep_indices]
         position_ids = self.position_ids[keep_indices]
         self.attention_mask = self.attention_mask[
             keep_indices,
-            -(self.padding_right_offset + max_input_length):
-            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
+            -(self.padding_right_offset + max_input_length) : (
+                self.attention_mask.shape[1] - self.padding_right_offset
+            )
+            + new_padding_right_offset,
         ]

         # Ensure that past_key_values tensors can be updated in-place

@@ -203,6 +217,8 @@ class CausalLMBatch(Batch):
                 layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
                 del past_values

+        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
+
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
         self.input_ids = input_ids

@@ -215,6 +231,7 @@ class CausalLMBatch(Batch):
         self.stopping_criterias = stopping_criterias
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
+        self.max_tokens = max_tokens

         return self

@@ -239,6 +256,7 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
+        max_tokens = 0

         # Batch tensors
         input_ids = None

@@ -314,7 +332,8 @@ class CausalLMBatch(Batch):
             # And ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
-                batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
-                ]
+                batch.past_key_values = [
+                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    for layer in batch.past_key_values
+                ]
             elif batch.past_key_values[0][0].shape == 3:
                 for layer in batch.past_key_values:

@@ -322,6 +341,10 @@ class CausalLMBatch(Batch):
                     layer[k] = t.view(len(batch), -1, *t.shape[-2:])

             start_index = end_index
+            # Add eventual padding tokens that were added while concatenating
+            max_tokens += batch.max_tokens + (
+                max_input_length - batch.max_input_length
+            ) * len(batch)

         first_past_kvs = batches[0].past_key_values
         _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape

@@ -371,7 +394,9 @@ class CausalLMBatch(Batch):
                 start_index = end_index

-            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
+            padded_past_values = first_past_kvs[j][1].new_zeros(
+                padded_past_values_shape
+            )
             start_index = 0
             for batch in batches:
                 past_values = batch.past_key_values[j][1]

@@ -387,6 +412,7 @@ class CausalLMBatch(Batch):
                 ] = past_values[:, :, -past_seq_len:, :]
                 del past_values

+                # Update values
                 start_index = end_index

             past_key_values.append([padded_past_keys, padded_past_values])

@@ -408,6 +434,7 @@ class CausalLMBatch(Batch):
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
+            max_tokens=max_tokens,
         )

     def __len__(self):
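The `max_tokens` value the Python batches now carry is the server-side counterpart of the router's budget estimate: for a padded batch it is `len(inputs) * max_input_length + max_decode_tokens`, i.e. the size the batch can grow to if every request decodes to its limit. A quick worked example with hypothetical numbers (written in Rust only to match the other sketches on this page):

fn main() {
    // Hypothetical padded batch: 4 requests, longest prompt 50 tokens,
    // max_new_tokens of 10, 20, 30 and 40 respectively.
    let n_requests: u32 = 4;
    let max_input_length: u32 = 50;
    let max_decode_tokens: u32 = 10 + 20 + 30 + 40;
    let max_tokens = n_requests * max_input_length + max_decode_tokens;
    assert_eq!(max_tokens, 300); // 200 padded prefill tokens + 100 decode tokens
}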
server/text_generation_server/models/flash_causal_lm.py

@@ -56,9 +56,15 @@ class FlashCausalLMBatch(Batch):
     # Constant shared tensor, ref here just so that it's accessible in concatentate()
     past_pad: Optional[torch.Tensor]

+    # Maximum number of tokens this batch will grow to
+    max_tokens: int
+
     def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
-            id=self.batch_id, requests=self.requests, size=len(self)
+            id=self.batch_id,
+            requests=self.requests,
+            size=len(self),
+            max_tokens=self.max_tokens,
         )

     @classmethod

@@ -86,6 +92,8 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_length = 0

+        max_tokens = 0
+
         # Parse batch
         for i, r in enumerate(pb.requests):
             # request id -> idx in list mapping

@@ -115,16 +123,20 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens.append(cumulative_length + input_length)

             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)

             all_input_ids_tensor.append(
                 F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
             )

             # Update
             cumulative_length += input_length
+            max_tokens += input_length + max_new_tokens

         return cls(
             batch_id=pb.id,

@@ -143,6 +155,7 @@ class FlashCausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             past_pad=None,
+            max_tokens=max_tokens,
         )

     @tracer.start_as_current_span("filter")

@@ -177,6 +190,8 @@ class FlashCausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []

+        max_tokens = 0
+
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
             requests_idx_mapping[r.id] = i

@@ -203,9 +218,14 @@ class FlashCausalLMBatch(Batch):
             token_offsets.append(self.token_offsets[idx])

             next_token_choosers.append(self.next_token_choosers[idx])
-            stopping_criterias.append(self.stopping_criterias[idx])
+
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)

             cumulative_length += request_input_length
+            max_tokens += request_input_length + (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )

         if single_request:
             # Preallocate tensor for bs = 1 case

@@ -241,6 +261,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            max_tokens=max_tokens,
         )

     @classmethod

@@ -269,6 +290,7 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_batch_size = 0
         cumulative_length = 0
+        max_tokens = 0

         for i, batch in enumerate(batches):
             requests.extend(batch.requests)

@@ -310,6 +332,7 @@ class FlashCausalLMBatch(Batch):
             # Update
             cumulative_length += batch.cu_seqlens[-1]
             cumulative_batch_size += len(batch)
+            max_tokens += batch.max_tokens

         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,

@@ -328,6 +351,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            max_tokens=max_tokens,
         )

     def __len__(self):
server/text_generation_server/models/galactica.py

@@ -101,6 +101,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
+        max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic

@@ -113,6 +114,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
+            max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )

@@ -141,6 +143,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

+        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,

@@ -157,6 +161,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
         )
server/text_generation_server/models/seq2seq_lm.py

@@ -54,10 +54,16 @@ class Seq2SeqLMBatch(Batch):
     max_decoder_input_length: int
     padding_right_offset: int

+    # Maximum number of tokens this batch will grow to
+    max_tokens: int
+
     def to_pb(self) -> generate_pb2.Batch:
         """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
         return generate_pb2.Batch(
-            id=self.batch_id, requests=self.requests, size=len(self)
+            id=self.batch_id,
+            requests=self.requests,
+            size=len(self),
+            max_tokens=self.max_tokens,
         )

     @classmethod

@@ -80,6 +86,7 @@ class Seq2SeqLMBatch(Batch):
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
+        max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i

@@ -92,6 +99,7 @@ class Seq2SeqLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
+            max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )

@@ -117,6 +125,8 @@ class Seq2SeqLMBatch(Batch):
         )
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

+        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,

@@ -137,6 +147,7 @@ class Seq2SeqLMBatch(Batch):
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
         )

     @tracer.start_as_current_span("filter")

@@ -166,6 +177,8 @@ class Seq2SeqLMBatch(Batch):
         max_decoder_input_length = 0
         padding_right_offset = 0

+        remaining_decode_tokens = 0
+
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
             requests_idx_mapping[r.id] = i

@@ -187,27 +200,38 @@ class Seq2SeqLMBatch(Batch):
             )
             padding_right_offset = max(
                 padding_right_offset,
-                self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
+                self.stopping_criterias[idx].max_new_tokens
+                - self.stopping_criterias[idx].current_tokens,
             )

             next_token_choosers.append(self.next_token_choosers[idx])
-            stopping_criterias.append(self.stopping_criterias[idx])
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)
+            remaining_decode_tokens += (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         self.decoder_input_ids = self.decoder_input_ids[keep_indices]
         self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
         if self.decoder_attention_mask is not None:
             self.decoder_attention_mask = self.decoder_attention_mask[
                 keep_indices,
-                -(self.padding_right_offset + max_decoder_input_length):
-                (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
+                -(self.padding_right_offset + max_decoder_input_length) : (
+                    self.decoder_attention_mask.shape[1] - self.padding_right_offset
+                )
+                + padding_right_offset,
             ]

-        self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
+        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
+            keep_indices, -max_input_length:
+        ]

         # Ensure that past_key_values tensors can be updated in-place
         if type(self.past_key_values[0]) == tuple:
-            self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
+            self.past_key_values = [
+                [t for t in layer] for layer in self.past_key_values
+            ]

         decoder_past_seq_len = max_decoder_input_length - 1
         for layer in self.past_key_values:

@@ -216,6 +240,11 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]

+        max_tokens = (
+            len(requests) * (max_input_length + max_decoder_input_length)
+            + remaining_decode_tokens
+        )
+
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
         self.input_ids = None

@@ -229,10 +258,10 @@ class Seq2SeqLMBatch(Batch):
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
+        self.max_tokens = max_tokens

         return self

     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":

@@ -261,6 +290,7 @@ class Seq2SeqLMBatch(Batch):
         token_offsets = []
         next_token_choosers = []
         stopping_criterias = []
+        max_tokens = 0

         # Batch tensors
         attention_mask = None

@@ -363,9 +393,18 @@ class Seq2SeqLMBatch(Batch):
             # Ensure that we can update tensors in-place
             if type(batch.past_key_values[0]) == tuple:
-                batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
+                batch.past_key_values = [
+                    [t for t in layer] for layer in batch.past_key_values
+                ]

             start_index = end_index

+            # Add eventual padding tokens that were added while concatenating
+            max_tokens += batch.max_tokens + (
+                max_input_length
+                - batch.max_input_length
+                + max_decoder_input_length
+                - batch.max_decoder_input_length
+            ) * len(batch)
+
         # Determine shapes for new past kv tensors
         first_past_kvs = batches[0].past_key_values

@@ -404,9 +443,9 @@ class Seq2SeqLMBatch(Batch):
             end_index = start_index + len(batch)
             # We slice the past keys and values to remove the padding from previous batches
             past_seq_len = batch.max_decoder_input_length - 1
-            padded_past_values[
-                start_index:end_index, :, -past_seq_len:, :
-            ] = t[:, :, -past_seq_len:, :]
+            padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
+                :, :, -past_seq_len:, :
+            ]
             del t

             start_index = end_index

@@ -426,8 +465,8 @@ class Seq2SeqLMBatch(Batch):
             end_index = start_index + len(batch)
             # We slice the past keys and values to remove the padding from previous batches
             padded_past_values[
-                start_index:end_index, :, -batch.max_input_length:, :
-            ] = t[:, :, -batch.max_input_length:, :]
+                start_index:end_index, :, -batch.max_input_length :, :
+            ] = t[:, :, -batch.max_input_length :, :]
             del t

             start_index = end_index

@@ -452,6 +491,7 @@ class Seq2SeqLMBatch(Batch):
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
         )

     def __len__(self):
server/text_generation_server/server.py

@@ -41,6 +41,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             torch.cuda.empty_cache()
         return generate_pb2.ClearCacheResponse()

+    async def FilterBatch(self, request, context):
+        batch = self.cache.pop(request.batch_id)
+        if batch is None:
+            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
+        filtered_batch = batch.filter(request.keep_requests)
+        self.cache.set(filtered_batch)
+
+        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+
     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.device

@@ -63,9 +72,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = self.cache.pop(batch_pb.id)
             if batch is None:
                 raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
-            batch = batch.filter(batch_pb.requests)
-            if batch is not None:
-                batches.append(batch)
+            batches.append(batch)

         if len(batches) == 0:
             raise ValueError("All batches are empty")