OpenDAS / text-generation-inference · Commits · 8aece3bd

Commit 8aece3bd (unverified), authored Jun 05, 2024 by OlivierDehaene, committed by GitHub on Jun 05, 2024.
Parent: 9ffe1f1e

feat: move allocation logic to rust (#1835)

Close #2007

Changes: 25 files; this page shows 20 changed files with 333 additions and 240 deletions (+333 −240).
Cargo.toml                                                                           +6   −1
Dockerfile                                                                           +5   −5
Dockerfile_amd                                                                       +5   −5
Dockerfile_intel                                                                     +5   −5
benchmark/src/generation.rs                                                          +3   −0
proto/v3/generate.proto                                                              +6   −0
router/client/src/v3/client.rs                                                       +5   −1
router/client/src/v3/sharded_client.rs                                               +4   −0
router/src/infer/v3/block_allocator.rs                                               +136 −0
router/src/infer/v3/mod.rs                                                           +1   −0
router/src/infer/v3/queue.rs                                                         +142 −82
router/src/infer/v3/scheduler.rs                                                     +8   −1
server/text_generation_server/models/cache_manager.py                                +0   −140
server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py        +1   −0
server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py          +1   −0
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py         +1   −0
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py          +1   −0
server/text_generation_server/models/custom_modeling/flash_phi_modeling.py           +1   −0
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py            +1   −0
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py    +1   −0
Cargo.toml

@@ -26,7 +26,12 @@ incremental = true
 inherits = "release"
 debug = 1
 incremental = true
+panic = "abort"
+
+[profile.release-opt]
+inherits = "release"
+debug = 0
+incremental = false
 lto = "fat"
 opt-level = 3
 codegen-units = 1
-panic = "abort"
Dockerfile

@@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP
 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@@ -226,11 +226,11 @@ RUN cd server && \
     pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
Dockerfile_amd

@@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP
 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 # Text Generation Inference base image for RoCm
 FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
@@ -193,11 +193,11 @@ RUN cd server && \
     pip install ".[accelerate, peft, outlines]" --no-cache-dir
 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 # AWS Sagemaker compatible image
 FROM base as sagemaker
Dockerfile_intel

@@ -24,7 +24,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP
 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -32,7 +32,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 # Text Generation Inference base image for Intel
@@ -78,11 +78,11 @@ ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mp
 ENV CCL_ZE_IPC_EXCHANGE=sockets
 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 # Final image
 FROM base
benchmark/src/generation.rs

@@ -155,6 +155,8 @@ async fn prefill(
                 ignore_eos_token: true, // Will not stop even if a eos token is generated
             }),
             top_n_tokens: top_n_tokens.unwrap_or(0),
+            blocks: vec![],
+            slots: vec![],
         })
         .collect();
@@ -163,6 +165,7 @@ async fn prefill(
         requests,
         size: batch_size,
         max_tokens: batch_size * (sequence_length + decode_length),
+        max_blocks: 0,
     };
     // Run prefill
proto/v3/generate.proto

@@ -130,6 +130,10 @@ message Request {
    bool prefill_logprobs = 6;
    /// Return most likely n tokens
    uint32 top_n_tokens = 7;
+   /// Paged attention blocks
+   repeated uint32 blocks = 9;
+   /// Paged attention slots
+   repeated uint32 slots = 10;
 }
 message Batch {
@@ -141,6 +145,8 @@ message Batch {
    uint32 size = 3;
    /// Maximum number of tokens this batch will grow to
    uint32 max_tokens = 4;
+   /// Maximum number of Paged Attention blocks
+   uint32 max_blocks = 5;
 }
 message CachedBatch {
router/client/src/v3/client.rs

@@ -153,6 +153,9 @@ impl Client {
                 }),
                 // We truncate the input on the server side to be sure that it has the correct size
                 truncate,
+                // Blocks and slots will be set on the server side if we use paged attention
+                blocks: vec![],
+                slots: vec![],
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -187,7 +190,8 @@ impl Client {
             id: 0,
             size: requests.len() as u32,
             requests,
-            max_tokens: 0,
+            max_tokens: max_input_length,
+            max_blocks: 0,
         };
         let request = tonic::Request::new(WarmupRequest {
router/client/src/v3/sharded_client.rs

@@ -241,12 +241,16 @@ impl Health for ShardedClient {
                 ignore_eos_token: false,
             }),
             top_n_tokens: 0,
+            // Block 0 is reserved for health checks
+            blocks: vec![0],
+            slots: (0..16).collect(),
         };
         let batch = Batch {
             id: u64::MAX,
             requests: vec![liveness_request],
             size: 1,
             max_tokens: 2,
+            max_blocks: 1,
         };
         self.clone().prefill(batch).await?;
         Ok(())
router/src/infer/v3/block_allocator.rs (new file)

use std::cmp::min;
use tokio::sync::{mpsc, oneshot};

#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
    pub blocks: Vec<u32>,
    pub slots: Vec<u32>,
    block_allocator: BlockAllocator,
}

impl Drop for BlockAllocation {
    fn drop(&mut self) {
        self.block_allocator.free(self.blocks.clone())
    }
}

#[derive(Debug, Clone)]
pub(crate) struct BlockAllocator {
    /// Channel to communicate with the background task
    block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
}

impl BlockAllocator {
    pub(crate) fn new(
        max_batch_total_tokens: u32,
        block_size: u32,
        window_size: Option<u32>,
    ) -> Self {
        // Create channel
        let (sender, receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(block_allocator_task(
            max_batch_total_tokens / block_size,
            block_size,
            window_size,
            receiver,
        ));

        Self {
            block_allocator: sender,
        }
    }

    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
        let (response_sender, response_receiver) = oneshot::channel();
        self.block_allocator
            .send(BlockAllocatorCommand::Allocate {
                tokens,
                response_sender,
            })
            .unwrap();

        response_receiver
            .await
            .unwrap()
            .map(|(blocks, slots)| BlockAllocation {
                blocks,
                slots,
                block_allocator: self.clone(),
            })
    }

    pub(crate) fn free(&self, blocks: Vec<u32>) {
        self.block_allocator
            .send(BlockAllocatorCommand::Free { blocks })
            .unwrap();
    }
}

async fn block_allocator_task(
    blocks: u32,
    block_size: u32,
    window_size: Option<u32>,
    mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
    // Block 0 is reserved for health checks
    let mut free_blocks: Vec<u32> = (1..blocks).collect();
    while let Some(cmd) = receiver.recv().await {
        match cmd {
            BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
            BlockAllocatorCommand::Allocate {
                tokens,
                response_sender,
            } => {
                // Apply window size
                let (required_blocks, repeats) = {
                    let (tokens, repeats) = match window_size {
                        None => (tokens, 1),
                        Some(window_size) => {
                            let repeats = (tokens + window_size - 1) / window_size;
                            let tokens = min(tokens, window_size);
                            (tokens, repeats as usize)
                        }
                    };
                    // Pad to a multiple of block size
                    let required_blocks = (tokens + block_size - 1) / block_size;
                    (required_blocks, repeats)
                };

                let tokens = tokens as usize;
                let allocation = if required_blocks > free_blocks.len() as u32 {
                    None
                } else {
                    let blocks =
                        free_blocks.split_off(free_blocks.len() - required_blocks as usize);
                    let mut slots = Vec::with_capacity(
                        (required_blocks * block_size * repeats as u32) as usize,
                    );

                    'slots: for block_id in blocks.repeat(repeats).iter() {
                        for s in (block_id * block_size)..((block_id + 1) * block_size) {
                            slots.push(s);
                            if slots.len() == tokens {
                                break 'slots;
                            }
                        }
                    }
                    Some((blocks, slots))
                };

                response_sender.send(allocation).unwrap();
            }
        }
    }
}

#[derive(Debug)]
enum BlockAllocatorCommand {
    Free {
        blocks: Vec<u32>,
    },
    Allocate {
        tokens: u32,
        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
    },
}
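The allocator runs as a background Tokio task: `allocate` sends a command over an unbounded channel and awaits the reply, and the returned `BlockAllocation` hands its blocks back through its `Drop` impl, so an entry that is dropped (for example, a cancelled request) frees its blocks automatically. A minimal usage sketch, assuming the type is reachable from the caller; the parameter values (1024 total tokens, block size 16) are illustrative, not taken from the commit:

    // Sketch only: values below are made up for the example.
    async fn demo() {
        // max_batch_total_tokens = 1024, block_size = 16, no sliding window
        let allocator = BlockAllocator::new(1024, 16, None);

        // Reserve enough blocks/slots to hold 100 tokens.
        if let Some(allocation) = allocator.allocate(100).await {
            // 100 tokens padded to blocks of 16 -> 7 blocks, 100 slots.
            assert_eq!(allocation.blocks.len(), 7);
            assert_eq!(allocation.slots.len(), 100);
        } // `allocation` dropped here: its blocks return to the free list.
    }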
router/src/infer/v3/mod.rs

+mod block_allocator;
 mod queue;
 mod scheduler;
router/src/infer/v3/queue.rs

-use crate::infer::{InferError, InferStreamResponse};
+use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
+use crate::infer::InferError;
+use crate::infer::InferStreamResponse;
 use crate::validation::{
     ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::min;
+use std::cmp::{max, min};
 use std::collections::VecDeque;
 use text_generation_client::v3::{
     Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
-use text_generation_client::{ChunksToString, Input};
+use text_generation_client::ChunksToString;
+use text_generation_client::Input;
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
-use tracing::{info_span, instrument, Span};
+use tracing::{info_span, instrument, Instrument, Span};
 /// Queue entry
 #[derive(Debug)]
@@ -28,6 +31,8 @@ pub(crate) struct Entry {
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
+    /// Block Allocation
+    pub block_allocation: Option<BlockAllocation>,
 }
 /// Request Queue
@@ -43,6 +48,7 @@ impl Queue {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_batch_total_tokens: u32,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -53,12 +59,14 @@ impl Queue {
             block_size,
             window_size,
             speculate,
+            max_batch_total_tokens,
             queue_receiver,
         ));
         Self { queue_sender }
     }
+    /// Append an entry to the queue
     #[instrument(skip_all)]
     pub(crate) fn append(&self, entry: Entry) {
         // Send append command to the background task managing the state
@@ -103,9 +111,16 @@ async fn queue_task(
     block_size: u32,
     window_size: Option<u32>,
     speculate: u32,
+    max_batch_total_tokens: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size, speculate);
+    let mut state = State::new(
+        requires_padding,
+        block_size,
+        window_size,
+        speculate,
+        max_batch_total_tokens,
+    );
     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -120,12 +135,14 @@ async fn queue_task(
                 token_budget,
                 response_sender,
                 span,
-            } => span.in_scope(|| {
-                let next_batch =
-                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
+            } => {
+                let next_batch = state
+                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
+                    .instrument(span)
+                    .await;
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
-            }),
+            }
         }
     }
 }
@@ -142,9 +159,6 @@ struct State {
     /// Id of the next batch
     next_batch_id: u64,
-    /// Whether the model is using padding
-    requires_padding: bool,
-
     /// Paged Attention block size
     block_size: u32,
@@ -153,6 +167,9 @@ struct State {
     /// Speculation amount
     speculate: u32,
+
+    /// Paged Attention Block Allocation
+    block_allocator: Option<BlockAllocator>,
 }
 impl State {
@@ -161,15 +178,19 @@ impl State {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_batch_total_tokens: u32,
     ) -> Self {
+        let block_allocator = (!requires_padding)
+            .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
+
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
-            requires_padding,
             block_size,
             window_size,
             speculate,
+            block_allocator,
         }
     }
@@ -185,7 +206,7 @@ impl State {
     }
     // Get the next batch
-    fn next_batch(
+    async fn next_batch(
         &mut self,
         min_size: Option<usize>,
         max_size: Option<usize>,
@@ -220,9 +241,10 @@ impl State {
         let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;
+        let mut max_blocks = 0;
         // Pop entries starting from the front of the queue
-        while let Some((id, mut entry)) = self.entries.pop_front() {
+        'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_closed() {
@@ -231,43 +253,67 @@ impl State {
                 continue;
             }
-            if self.requires_padding {
-                // We pad to max input length in the Python shards
-                // We need to take these padding tokens into the equation
-                max_input_length = max_input_length.max(entry.request.input_length);
-                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
-            } else {
-                // pad to block size
-                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
-                    / self.block_size)
-                    * self.block_size;
-            }
-
-            if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
-            } else {
-                let max_new_tokens = match self.window_size {
-                    None => entry.request.stopping_parameters.max_new_tokens,
-                    Some(window_size) => min(
-                        window_size.saturating_sub(entry.request.input_length),
-                        entry.request.stopping_parameters.max_new_tokens,
-                    ),
-                };
-
-                // pad to block size
-                decode_tokens +=
-                    ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
-            }
-
-            if prefill_tokens > prefill_token_budget
-                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-            {
-                // Entry is over budget
-                // Add it back to the front
-                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
-                self.entries.push_front((id, entry));
-                break;
-            }
+            let block_allocation = match &self.block_allocator {
+                None => {
+                    // We pad to max input length in the Python shards
+                    // We need to take these padding tokens into the equation
+                    max_input_length = max_input_length.max(entry.request.input_length);
+                    prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
+
+                    decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                    let total_tokens = prefill_tokens + decode_tokens + self.speculate;
+
+                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
+                        // Entry is over budget
+                        // Add it back to the front
+                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                        self.entries.push_front((id, entry));
+                        break 'entry_loop;
+                    }
+                    None
+                }
+                Some(block_allocator) => {
+                    prefill_tokens += entry.request.input_length;
+                    let max_new_tokens = match self.window_size {
+                        None => entry.request.stopping_parameters.max_new_tokens,
+                        Some(window_size) => min(
+                            window_size.saturating_sub(entry.request.input_length),
+                            entry.request.stopping_parameters.max_new_tokens,
+                        ),
+                    };
+                    decode_tokens += max_new_tokens;
+
+                    if prefill_tokens > prefill_token_budget
+                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
+                    {
+                        // Entry is over budget
+                        // Add it back to the front
+                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                        self.entries.push_front((id, entry));
+                        break;
+                    }
+
+                    let tokens = entry.request.input_length
+                        + entry.request.stopping_parameters.max_new_tokens
+                        + self.speculate
+                        - 1;
+
+                    match block_allocator.allocate(tokens).await {
+                        None => {
+                            // Entry is over budget
+                            // Add it back to the front
+                            tracing::debug!("Over budget: not enough free blocks");
+                            self.entries.push_front((id, entry));
+                            break 'entry_loop;
+                        }
+                        Some(block_allocation) => {
+                            tracing::debug!("Allocation: {block_allocation:?}");
+                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
+                            Some(block_allocation)
+                        }
+                    }
+                }
+            };
             tracing::debug!("Accepting entry");
             // Create a new span to link the batch back to this entry
@@ -278,13 +324,23 @@ impl State {
             // Update entry
             entry.temp_span = Some(entry_batch_span);
+            let (blocks, slots) = match &block_allocation {
+                None => (Vec::new(), Vec::new()),
+                Some(block_allocation) => (
+                    block_allocation.blocks.clone(),
+                    block_allocation.slots.clone(),
+                ),
+            };
+
+            entry.block_allocation = block_allocation;
+
             batch_requests.push(Request {
                 id,
                 prefill_logprobs: entry.request.decoder_input_details,
-                inputs: entry.request.inputs.chunks_to_string(),
                 input_chunks: Some(Input {
                     chunks: entry.request.inputs.clone(),
                 }),
+                inputs: entry.request.inputs.chunks_to_string(),
                 truncate: entry.request.truncate,
                 parameters: Some(NextTokenChooserParameters::from(
                     entry.request.parameters.clone(),
@@ -293,6 +349,8 @@ impl State {
                     entry.request.stopping_parameters.clone(),
                 )),
                 top_n_tokens: entry.request.top_n_tokens,
+                blocks,
+                slots,
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
@@ -335,6 +393,7 @@ impl State {
             requests: batch_requests,
             size,
             max_tokens: (prefill_tokens + decode_tokens),
+            max_blocks,
         };
         // Increment batch id
         self.next_batch_id += 1;
@@ -438,13 +497,14 @@ mod tests {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
+            block_allocation: None,
         };
         (entry, receiver_tx)
     }
-    #[test]
-    fn test_append() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_append() {
+        let mut state = State::new(false, 1, None, 0, 16);
         let (entry, _guard) = default_entry();
         assert_eq!(state.next_id, 0);
@@ -458,23 +518,23 @@ mod tests {
         assert_eq!(id, 0);
     }
-    #[test]
-    fn test_next_batch_empty() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_empty() {
+        let mut state = State::new(false, 1, None, 0, 16);
-        assert!(state.next_batch(None, None, 1, 1).is_none());
-        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
+        assert!(state.next_batch(None, None, 1, 1).await.is_none());
+        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
     }
-    #[test]
-    fn test_next_batch_min_size() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_min_size() {
+        let mut state = State::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);
-        let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -490,7 +550,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);
-        assert!(state.next_batch(Some(2), None, 2, 2).is_none());
+        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
@@ -498,15 +558,15 @@ mod tests {
         assert_eq!(id, 2);
     }
-    #[test]
-    fn test_next_batch_max_size() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_max_size() {
+        let mut state = State::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);
-        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert!(entries.get(&0).unwrap().batch_time.is_some());
@@ -518,15 +578,15 @@ mod tests {
         assert_eq!(state.next_batch_id, 1);
     }
-    #[test]
-    fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_token_budget() {
+        let mut state = State::new(false, 1, None, 0, 2);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);
-        let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -539,7 +599,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);
-        let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -553,14 +613,14 @@ mod tests {
     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }
     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         assert!(queue.next_batch(None, None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -568,7 +628,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -601,7 +661,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -617,7 +677,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -642,7 +702,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, None, 2);
+        let queue = Queue::new(false, 1, None, 2, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -661,7 +721,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry, _) = default_entry();
         queue.append(entry);
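For each accepted entry the queue now asks the allocator for `input_length + max_new_tokens + speculate - 1` tokens, and the allocator rounds that up to whole blocks, repeating blocks when a sliding window is configured. A small standalone sketch of that rounding, extracted for illustration only (the free function `required_blocks` is mine, not part of the commit; it mirrors the math in `block_allocator_task`):

    use std::cmp::min;

    // Tokens are clamped to the window (if any) and padded up to whole blocks;
    // `repeats` says how many times the blocks are reused when laying out slots.
    fn required_blocks(tokens: u32, block_size: u32, window_size: Option<u32>) -> (u32, usize) {
        let (tokens, repeats) = match window_size {
            None => (tokens, 1),
            Some(window_size) => {
                let repeats = (tokens + window_size - 1) / window_size; // ceil(tokens / window)
                (min(tokens, window_size), repeats as usize)
            }
        };
        ((tokens + block_size - 1) / block_size, repeats) // ceil(tokens / block_size)
    }

    fn main() {
        // 100 tokens, block size 16, no window: ceil(100 / 16) = 7 blocks, 1 repeat.
        assert_eq!(required_blocks(100, 16, None), (7, 1));
        // 100 tokens, block size 16, window 32: ceil(32 / 16) = 2 blocks,
        // reused ceil(100 / 32) = 4 times when slots are laid out.
        assert_eq!(required_blocks(100, 16, Some(32)), (2, 4));
    }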
router/src/infer/v3/scheduler.rs

@@ -39,7 +39,13 @@ impl SchedulerV3 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
-        let queue = Queue::new(requires_padding, 16, window_size, speculate);
+        let queue = Queue::new(
+            requires_padding,
+            16,
+            window_size,
+            speculate,
+            max_batch_total_tokens,
+        );
         let batching_task_notifier = Arc::new(Notify::new());
         // Spawn batching background task that contains all the inference logic
@@ -81,6 +87,7 @@ impl Scheduler for SchedulerV3 {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
+            block_allocation: None,
         });
         // Notify the background task that we have a new entry in the queue that needs
server/text_generation_server/models/cache_manager.py (deleted, 100644 → 0)

import math
import torch

from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM

BLOCK_SIZE: int = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        repeat_slots: bool,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks
        self.repeat_slots = repeat_slots

        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "xpu":
            x = 1
        else:
            x = self.block_size // element_size

        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int64
        ).view(num_blocks, self.block_size)

    def allocate(
        self,
        needed_blocks_slots: List[Tuple[int, int]],
        blocks: int,
        max_blocks: int,
        device: torch.device,
    ):
        # Get free blocks indices by finding values in mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        if blocks > len(free_block_indices):
            raise RuntimeError(
                f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
            )

        # Slice by the number of required blocks
        block_indices = free_block_indices[:blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(needed_blocks_slots), max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            all_slots = self.slots[allocated_blocks].flatten()

            # Repeat slots in the case of context sliding window
            if needed_slots > len(all_slots) and self.repeat_slots:
                repeats = math.ceil(needed_slots / len(all_slots))
                all_slots = all_slots.repeat(repeats)

            allocated_slots = all_slots[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        block_tables = block_tables
        block_tables_tensor = block_tables_tensor.to(device)
        slots = torch.concat(slots).to(device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

        return block_tables, block_tables_tensor, slots

    def free(self, block_indices: Optional[List[int]]):
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


def set_cache_manager(
    num_blocks: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    repeat_slots: bool,
    dtype: torch.dtype,
    device: torch.device,
) -> CacheManager:
    global CACHE_MANAGER
    if CACHE_MANAGER is not None:
        del CACHE_MANAGER
        torch.cuda.empty_cache()

    CACHE_MANAGER = CacheManager(
        num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
    )
    return CACHE_MANAGER


def get_cache_manager() -> CacheManager:
    global CACHE_MANAGER
    if CACHE_MANAGER is None:
        raise RuntimeError("cache manager was not initialized")

    return CACHE_MANAGER
server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py

@@ -512,6 +512,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py

@@ -834,6 +834,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py

@@ -458,6 +458,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_embeds = self.embed_tokens(input_ids)
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

@@ -388,6 +388,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.gpt_neox(
server/text_generation_server/models/custom_modeling/flash_phi_modeling.py

@@ -398,6 +398,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

@@ -670,6 +670,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

@@ -482,6 +482,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(