OpenDAS / text-generation-inference · Commits · 53214633

Unverified commit 53214633, authored Feb 09, 2024 by OlivierDehaene, committed by GitHub on Feb 09, 2024

feat(router): add max_batch_size (#1542)
Some hardware requires a maximum batch size.
Parent: a4e58016
Showing 9 changed files with 118 additions and 21 deletions.
docs/source/basic_tutorials/launcher.md    +8   -0
launcher/src/main.rs                       +11  -0
router/client/src/client.rs                +6   -0
router/client/src/sharded_client.rs        +7   -1
router/src/infer.rs                        +11  -2
router/src/lib.rs                          +2   -0
router/src/main.rs                         +5   -0
router/src/queue.rs                        +65  -18
router/src/server.rs                       +3   -0
docs/source/basic_tutorials/launcher.md (+8 -0)

@@ -197,6 +197,14 @@ Options:
           [env: MAX_WAITING_TOKENS=]
           [default: 20]

+```
+## MAX_BATCH_SIZE
+```shell
+      --max-batch-size <MAX_BATCH_SIZE>
+          Enforce a maximum number of requests per batch
+          Specific flag for hardware targets that do not support unpadded inference
+
+          [env: MAX_BATCH_SIZE=]
 ```
 ## HOSTNAME
 ```shell
launcher/src/main.rs (+11 -0)

@@ -279,6 +279,11 @@ struct Args {
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,

+    /// Enforce a maximum number of requests per batch
+    /// Specific flag for hardware targets that do not support unpadded inference
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
@@ -1046,6 +1051,12 @@ fn spawn_webserver(
         router_args.push(max_batch_total_tokens.to_string());
     }

+    // Router optional max batch size
+    if let Some(max_batch_size) = args.max_batch_size {
+        router_args.push("--max-batch-size".to_string());
+        router_args.push(max_batch_size.to_string());
+    }
+
     // Model optional revision
     if let Some(ref revision) = args.revision {
         router_args.push("--revision".to_string());
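The launcher side is plumbing: the new field is parsed like the others and only forwarded to the router when it is actually set, so leaving it unset keeps the router uncapped. A minimal sketch of that forwarding pattern, with a hypothetical `push_optional_flag` helper standing in for the body of `spawn_webserver`:

```rust
// Sketch of the launcher's "optional flag" forwarding pattern.
// `max_batch_size` and `router_args` mirror the names used in the diff;
// the helper function itself is hypothetical.
fn push_optional_flag(router_args: &mut Vec<String>, max_batch_size: Option<usize>) {
    if let Some(max_batch_size) = max_batch_size {
        router_args.push("--max-batch-size".to_string());
        router_args.push(max_batch_size.to_string());
    }
}

fn main() {
    let mut router_args: Vec<String> = Vec::new();
    push_optional_flag(&mut router_args, Some(4));
    assert_eq!(router_args, vec!["--max-batch-size", "4"]);

    let mut no_cap: Vec<String> = Vec::new();
    push_optional_flag(&mut no_cap, None);
    assert!(no_cap.is_empty()); // flag omitted, router keeps its default (no cap)
}
```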
router/client/src/client.rs (+6 -0)

@@ -105,6 +105,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
@@ -137,6 +138,11 @@ impl Client {
                 top_n_tokens: 20,
             });
             n_tokens += max_input_length;
+
+            // Check max_batch_size
+            if Some(requests.len()) == max_batch_size {
+                break;
+            }
         }

         let batch = Batch {
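In `warmup`, the client keeps adding synthetic requests until the prefill-token budget is filled; the new check also stops as soon as the request count reaches `max_batch_size`, so hardware with a fixed batch size is never asked to warm up a larger batch. A standalone sketch of that loop, using a hypothetical `plan_warmup_requests` function that only counts requests instead of building gRPC protos:

```rust
// Sketch of the warmup capping logic added in client.rs (simplified; the
// real code builds request protos and accumulates n_tokens the same way).
fn plan_warmup_requests(
    max_prefill_tokens: u32,
    max_input_length: u32,
    max_batch_size: Option<usize>,
) -> usize {
    let mut n_tokens = 0;
    let mut requests = 0usize;
    while n_tokens < max_prefill_tokens {
        requests += 1;
        n_tokens += max_input_length;
        // Check max_batch_size: stop growing the synthetic batch once the
        // hardware's fixed batch size is reached.
        if Some(requests) == max_batch_size {
            break;
        }
    }
    requests
}

fn main() {
    // Without a cap, the batch is sized by the prefill-token budget alone...
    assert_eq!(plan_warmup_requests(4096, 1024, None), 4);
    // ...with a cap of 2, the loop exits early.
    assert_eq!(plan_warmup_requests(4096, 1024, Some(2)), 2);
}
```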
router/client/src/sharded_client.rs (+7 -1)

@@ -97,12 +97,18 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+                Box::pin(client.warmup(
+                    max_input_length,
+                    max_prefill_tokens,
+                    max_total_tokens,
+                    max_batch_size,
+                ))
             })
             .collect();
         // Take the minimum value
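The sharded client only threads the new argument through to each shard's `warmup` call; as before, the results are reduced by taking the minimum so that the most restrictive shard wins. A small sketch of that reduction step, assuming each shard reports its supported max batch total tokens as an `Option<u32>` (hypothetical `min_supported` helper, not the client code itself):

```rust
// Sketch of the "take the minimum value" reduction across shards: shards
// that report no limit are skipped, and the smallest reported limit wins so
// every shard can honour the chosen value.
fn min_supported(results: Vec<Option<u32>>) -> Option<u32> {
    results
        .into_iter()
        .flatten() // drop shards that reported no limit
        .min()     // the most restrictive shard wins
}

fn main() {
    assert_eq!(
        min_supported(vec![Some(16_000), Some(12_288), Some(20_480)]),
        Some(12_288)
    );
    assert_eq!(min_supported(vec![None, None]), None);
}
```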
router/src/infer.rs (+11 -2)

@@ -61,6 +61,7 @@ impl Infer {
         max_batch_prefill_tokens: u32,
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
+        max_batch_size: Option<usize>,
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
@@ -81,6 +82,7 @@ impl Infer {
             max_batch_prefill_tokens,
             max_batch_total_tokens,
             max_waiting_tokens,
+            max_batch_size,
             queue.clone(),
             shared.clone(),
             generation_health,
@@ -338,6 +340,7 @@ async fn batching_task(
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
     queue: Queue,
     shared: Arc<Shared>,
     generation_health: Arc<AtomicBool>,
@@ -351,7 +354,12 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
         while let Some((mut entries, batch, span)) = queue
-            .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
+            .next_batch(
+                None,
+                max_batch_size,
+                max_batch_prefill_tokens,
+                max_batch_total_tokens,
+            )
             .await
         {
             let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
@@ -379,10 +387,11 @@ async fn batching_task(
                 };

                 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+                let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);

                 // Try to get a new batch
                 if let Some((mut new_entries, new_batch, span)) = queue
-                    .next_batch(min_size, max_batch_prefill_tokens, token_budget)
+                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
                     .await
                 {
                     // Tracking metrics
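Two things change in the batching task: the initial `next_batch` call passes the configured cap, and when a batch is already running the task computes how many request slots are left before asking the queue to top it up. A minimal sketch of that headroom computation follows; note it uses `saturating_sub` defensively, whereas the diff subtracts directly since the running batch never exceeds the cap:

```rust
// Sketch of the headroom computation added to the batching task: with a
// running batch of `batch_size` requests, only `max_batch_size - batch_size`
// slots remain for the next prefill. Standalone illustration, not the
// router's own code.
fn remaining_slots(max_batch_size: Option<usize>, batch_size: usize) -> Option<usize> {
    max_batch_size.map(|max_size| max_size.saturating_sub(batch_size))
}

fn main() {
    // Hardware capped at 8 requests, 5 already running: room for 3 more.
    assert_eq!(remaining_slots(Some(8), 5), Some(3));
    // No cap configured: the token budgets alone bound the next batch.
    assert_eq!(remaining_slots(None, 5), None);
}
```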
router/src/lib.rs (+2 -0)

@@ -73,6 +73,8 @@ pub struct Info {
     pub max_batch_total_tokens: u32,
     #[schema(example = "20")]
     pub max_waiting_tokens: usize,
+    #[schema(nullable = true, example = "null")]
+    pub max_batch_size: Option<usize>,
     #[schema(example = "2")]
     pub validation_workers: usize,
     /// Router Info
router/src/main.rs (+5 -0)

@@ -45,6 +45,8 @@ struct Args {
     max_batch_total_tokens: Option<u32>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
@@ -91,6 +93,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         hostname,
         port,
         master_shard_uds_path,
@@ -288,6 +291,7 @@ async fn main() -> Result<(), RouterError> {
                 max_input_length as u32,
                 max_batch_prefill_tokens,
                 max_total_tokens as u32,
+                max_batch_size,
             )
             .await
             .map_err(RouterError::Warmup)?
@@ -344,6 +348,7 @@ async fn main() -> Result<(), RouterError> {
         max_batch_prefill_tokens,
         max_supported_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         sharded_client,
         tokenizer,
         validation_workers,
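Because the argument is declared with `#[clap(long, env)]`, it can be supplied either as `--max-batch-size` on the command line or via the `MAX_BATCH_SIZE` environment variable, matching the launcher docs above. A self-contained sketch of that declaration, assuming clap with the `derive` and `env` features enabled (a hypothetical standalone binary, not the router itself):

```rust
// Sketch of the clap declaration pattern used for the new flag.
// Requires clap with the "derive" and "env" features in Cargo.toml.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Enforce a maximum number of requests per batch
    #[clap(long, env)]
    max_batch_size: Option<usize>,
}

fn main() {
    // Either `--max-batch-size 4` or `MAX_BATCH_SIZE=4` populates the field;
    // leaving both unset yields `None`, i.e. no cap.
    let args = Args::parse();
    println!("max_batch_size = {:?}", args.max_batch_size);
}
```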
router/src/queue.rs (+65 -18)

@@ -70,6 +70,7 @@ impl Queue {
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
@@ -80,6 +81,7 @@ impl Queue {
         self.queue_sender
             .send(QueueCommand::NextBatch {
                 min_size,
+                max_size,
                 prefill_token_budget,
                 token_budget,
                 response_sender,
@@ -110,12 +112,14 @@ async fn queue_task(
             }
             QueueCommand::NextBatch {
                 min_size,
+                max_size,
                 prefill_token_budget,
                 token_budget,
                 response_sender,
                 span,
             } => span.in_scope(|| {
-                let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
+                let next_batch =
+                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
             }),
@@ -181,6 +185,7 @@ impl State {
     fn next_batch(
         &mut self,
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
     ) -> Option<NextBatch> {
@@ -274,6 +279,11 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
+
+            // Check if max_size
+            if Some(batch_requests.len()) == max_size {
+                break;
+            }
         }

         // Empty batch
@@ -322,6 +332,7 @@ enum QueueCommand {
     Append(Box<Entry>, Span),
     NextBatch {
         min_size: Option<usize>,
+        max_size: Option<usize>,
         prefill_token_budget: u32,
         token_budget: u32,
         response_sender: oneshot::Sender<Option<NextBatch>>,
@@ -394,8 +405,8 @@ mod tests {
     fn test_next_batch_empty() {
         let mut state = State::new(false, 1, None, 0);

-        assert!(state.next_batch(None, 1, 1).is_none());
-        assert!(state.next_batch(Some(1), 1, 1).is_none());
+        assert!(state.next_batch(None, None, 1, 1).is_none());
+        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
     }

     #[test]
@@ -406,7 +417,7 @@ mod tests {
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -422,7 +433,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        assert!(state.next_batch(Some(2), 2, 2).is_none());
+        assert!(state.next_batch(Some(2), None, 2, 2).is_none());

         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
@@ -430,6 +441,26 @@ mod tests {
         assert_eq!(id, 2);
     }

+    #[test]
+    fn test_next_batch_max_size() {
+        let mut state = State::new(false, 1, None, 0);
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);
+
+        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
+        assert_eq!(entries.len(), 1);
+        assert!(entries.contains_key(&0));
+        assert!(entries.get(&0).unwrap().batch_time.is_some());
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 1);
+
+        assert_eq!(state.next_id, 2);
+        assert_eq!(state.entries.len(), 1);
+        assert_eq!(state.next_batch_id, 1);
+    }
+
     #[test]
     fn test_next_batch_token_budget() {
         let mut state = State::new(false, 1, None, 0);
@@ -438,7 +469,7 @@ mod tests {
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -451,7 +482,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -474,8 +505,8 @@ mod tests {
     async fn test_queue_next_batch_empty() {
         let queue = Queue::new(false, 1, None, 0);

-        assert!(queue.next_batch(None, 1, 1).await.is_none());
-        assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
+        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
     }

     #[tokio::test]
@@ -486,7 +517,7 @@ mod tests {
         queue.append(entry1);
         queue.append(entry2);

-        let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -499,11 +530,11 @@ mod tests {
         queue.append(entry3);

         // Not enough requests pending
-        assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
+        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
         // Not enough token budget
-        assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
+        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
         // Ok
-        let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
+        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
         assert_eq!(entries2.len(), 1);
         assert!(entries2.contains_key(&2));
         assert!(entries2.get(&2).unwrap().batch_time.is_some());
@@ -511,6 +542,22 @@ mod tests {
         assert_eq!(batch2.size, 1);
     }

+    #[tokio::test]
+    async fn test_queue_next_batch_max_size() {
+        let queue = Queue::new(false, 1, None, 0);
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);
+
+        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
+        assert_eq!(entries.len(), 1);
+        assert!(entries.contains_key(&0));
+        assert!(entries.get(&0).unwrap().batch_time.is_some());
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 1);
+    }
+
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
         let queue = Queue::new(false, 1, None, 0);
@@ -519,7 +566,7 @@ mod tests {
         queue.append(entry1);
         queue.append(entry2);

-        let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -528,7 +575,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         queue.append(entry3);

-        let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -545,9 +592,9 @@ mod tests {
         queue.append(entry2);

         // Budget of 1 is not enough
-        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

-        let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -561,6 +608,6 @@ mod tests {
         let (entry, _) = default_entry();
         queue.append(entry);

-        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
     }
 }
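The queue is where the cap is actually enforced: `State::next_batch` now takes an optional `max_size` next to the token budgets and stops admitting requests once that many are in the batch, which is exactly what the new `test_next_batch_max_size` and `test_queue_next_batch_max_size` tests exercise. A condensed, standalone sketch of that bounding logic with plain token counts instead of queue entries (hypothetical `build_batch` function, not the router's code):

```rust
// Sketch of how a batch is bounded by both the token budget and the new
// max_size cap: the token budget may stop it early, and max_size caps the
// request count regardless of how many tokens are left.
fn build_batch(
    pending_tokens: &[u32],   // tokens needed by each queued request
    token_budget: u32,
    max_size: Option<usize>,
) -> Vec<usize> {
    let mut batch = Vec::new();
    let mut used = 0u32;
    for (id, tokens) in pending_tokens.iter().enumerate() {
        if used + tokens > token_budget {
            break; // token budget exhausted
        }
        used += tokens;
        batch.push(id);
        // Check if max_size: stop once the hardware batch-size cap is hit
        if Some(batch.len()) == max_size {
            break;
        }
    }
    batch
}

fn main() {
    let queued = [100, 100, 100, 100];
    // The token budget alone would admit 3 requests...
    assert_eq!(build_batch(&queued, 300, None), vec![0, 1, 2]);
    // ...but a max_size of 2 stops the batch earlier.
    assert_eq!(build_batch(&queued, 300, Some(2)), vec![0, 1]);
}
```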
router/src/server.rs (+3 -0)

@@ -768,6 +768,7 @@ pub async fn run(
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
     client: ShardedClient,
     tokenizer: Option<Tokenizer>,
     validation_workers: usize,
@@ -849,6 +850,7 @@ pub async fn run(
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
@@ -930,6 +932,7 @@ pub async fn run(
         waiting_served_ratio,
         max_batch_total_tokens,
         max_waiting_tokens,
+        max_batch_size,
         validation_workers,
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),