OpenDAS / text-generation-inference / Commits

Commit f9910d13 (unverified), authored Oct 23, 2023 by OlivierDehaene, committed by GitHub on Oct 23, 2023:
feat: remove flume (#1184)
parent 12590fdc

Showing 6 changed files with 75 additions and 96 deletions:

  Cargo.lock                    +6  -38
  router/Cargo.toml             +1  -1
  router/client/src/client.rs   +2  -3
  router/src/infer.rs           +25 -33
  router/src/queue.rs           +9  -9
  router/src/validation.rs      +32 -12
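All six files serve a single migration: flume channels are replaced by tokio::sync::mpsc unbounded channels, with tokio-stream providing the Stream adapter for receivers. As a reading aid, here is a minimal sketch (illustrative only, not code from this repository) of the API correspondence the diff applies over and over:

// Sketch of the flume -> tokio::sync::mpsc correspondence (illustrative only).
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

#[tokio::main]
async fn main() {
    // flume::unbounded()           ->  mpsc::unbounded_channel()
    let (tx, mut rx) = mpsc::unbounded_channel::<u32>();

    // sender.send_timeout(v, t)?   ->  sender.send(v)?
    // (an unbounded send never blocks, so no timeout is needed)
    tx.send(1).unwrap();

    // sender.is_disconnected()     ->  sender.is_closed()
    assert!(!tx.is_closed());

    // receiver.recv_async().await  ->  receiver.recv().await
    // (yields Option<T> instead of Result<T, RecvError>)
    assert_eq!(rx.recv().await, Some(1));

    // receiver.into_stream()       ->  UnboundedReceiverStream::new(receiver)
    let _stream = UnboundedReceiverStream::new(rx);
}

Because an unbounded tokio send can never block, the old send_timeout(..., Duration::from_millis(10)) calls reduce to plain send, and the error type in infer.rs changes from SendTimeoutError to SendError accordingly.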
Cargo.lock

@@ -743,18 +743,6 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
 
-[[package]]
-name = "flume"
-version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181"
-dependencies = [
- "futures-core",
- "futures-sink",
- "nanorand",
- "spin 0.9.8",
-]
-
 [[package]]
 name = "fnv"
 version = "1.0.7"

@@ -900,10 +888,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427"
 dependencies = [
  "cfg-if",
- "js-sys",
  "libc",
  "wasi",
- "wasm-bindgen",
 ]
 
 [[package]]

@@ -1508,15 +1494,6 @@ dependencies = [
  "tracing",
 ]
 
-[[package]]
-name = "nanorand"
-version = "0.7.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
-dependencies = [
- "getrandom",
-]
-
 [[package]]
 name = "native-tls"
 version = "0.2.11"

@@ -2313,7 +2290,7 @@ dependencies = [
  "cc",
  "libc",
  "once_cell",
- "spin 0.5.2",
+ "spin",
  "untrusted",
  "web-sys",
  "winapi",

@@ -2678,15 +2655,6 @@ version = "0.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
 
-[[package]]
-name = "spin"
-version = "0.9.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
-dependencies = [
- "lock_api",
-]
-
 [[package]]
 name = "spm_precompiled"
 version = "0.1.4"

@@ -2808,7 +2776,7 @@ dependencies = [
 [[package]]
 name = "text-generation-benchmark"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "average",
  "clap",

@@ -2829,7 +2797,7 @@ dependencies = [
 [[package]]
 name = "text-generation-client"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "futures",
  "grpc-metadata",

@@ -2845,7 +2813,7 @@ dependencies = [
 [[package]]
 name = "text-generation-launcher"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "clap",
  "ctrlc",

@@ -2861,13 +2829,12 @@ dependencies = [
 [[package]]
 name = "text-generation-router"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "async-stream",
  "axum",
  "axum-tracing-opentelemetry",
  "clap",
- "flume",
  "futures",
  "hf-hub 0.3.1",
  "init-tracing-opentelemetry",

@@ -2885,6 +2852,7 @@ dependencies = [
  "thiserror",
  "tokenizers",
  "tokio",
+ "tokio-stream",
  "tower-http",
  "tracing",
  "tracing-opentelemetry",
router/Cargo.toml

@@ -20,7 +20,6 @@ axum = { version = "0.6.20", features = ["json"] }
 axum-tracing-opentelemetry = "0.14.1"
 text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
-flume = "0.11.0"
 futures = "0.3.28"
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }

@@ -34,6 +33,7 @@ serde_json = "1.0.107"
 thiserror = "1.0.48"
 tokenizers = { version = "0.14.0", features = ["http"] }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.14"
 tower-http = { version = "0.4.4", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.21.0"
router/client/src/client.rs

@@ -107,15 +107,14 @@ impl Client {
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
-        let mut truncate = 0;
         // Create requests
         while n_tokens < max_prefill_tokens {
-            truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
                 inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate: truncate,
+                truncate,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
router/src/infer.rs

@@ -2,22 +2,21 @@
 use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
-use flume::r#async::RecvStream;
-use flume::SendTimeoutError;
 use futures::future::try_join_all;
-use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use std::time::Duration;
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
-use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::mpsc::error::SendError;
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tokio_stream::StreamExt;
 use tracing::{info_span, instrument, Instrument, Span};
 
 /// Inference struct

@@ -90,7 +89,7 @@ impl Infer {
     ) -> Result<
         (
             OwnedSemaphorePermit,
-            RecvStream<Result<InferStreamResponse, InferError>>,
+            UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
         ),
         InferError,
     > {

@@ -113,7 +112,7 @@ impl Infer {
         })?;
 
         // MPSC channel to communicate with the background batching task
-        let (response_tx, response_rx) = flume::unbounded();
+        let (response_tx, response_rx) = mpsc::unbounded_channel();
 
         // Append the request to the queue
         self.queue.append(Entry {

@@ -130,7 +129,7 @@ impl Infer {
         self.shared.batching_task.notify_one();
 
         // Return stream
-        Ok((permit, response_rx.into_stream()))
+        Ok((permit, UnboundedReceiverStream::new(response_rx)))
     }
 
     /// Add a new request to the queue and return a InferResponse

@@ -493,10 +492,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
         // If the receive an error from the Flume channel, it means that the client dropped the
         // request and we need to stop generating hence why we unwrap_or(true)
         let stopped = send_responses(generation, entry).map_err(|err| {
-            if let SendTimeoutError::Timeout(_) = *err {
-                tracing::error!("Entry response channel timed out.")
-            }
-
+            tracing::error!("Entry response channel error.");
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
             err
         }).unwrap_or(true);

@@ -510,9 +506,10 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
 fn send_responses(
     generation: Generation,
     entry: &Entry,
-) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
+) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
-    if entry.response_tx.is_disconnected() {
+    if entry.response_tx.is_closed() {
         metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
         return Ok(true);
     }

@@ -520,10 +517,9 @@ fn send_responses(
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Prefill(prefill_tokens)),
-            Duration::from_millis(10),
-        )?;
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
     }
 
     // Create last Token

@@ -558,22 +554,18 @@ fn send_responses(
         // Generation has ended
         stopped = true;
         // Send message
-        entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::End {
-                token,
-                top_tokens,
-                generated_text,
-                queued: entry.queue_time,
-                start: entry.batch_time.unwrap(),
-            }),
-            Duration::from_millis(10),
-        )?;
+        entry.response_tx.send(Ok(InferStreamResponse::End {
+            token,
+            top_tokens,
+            generated_text,
+            queued: entry.queue_time,
+            start: entry.batch_time.unwrap(),
+        }))?;
     } else {
         // Send message
-        entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Intermediate { token, top_tokens }),
-            Duration::from_millis(10),
-        )?;
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
     }
     Ok(stopped)
 }

@@ -591,7 +583,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry
             .response_tx
-            .send_timeout(Err(err), Duration::from_millis(10))
+            .send(Err(err))
             .unwrap_or(());
     });
 }
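Taken together, the infer.rs changes swap the response path from flume's RecvStream to an UnboundedReceiverStream, and every send now fails only if the client disconnected. A self-contained sketch of the new shape (placeholder Response type standing in for Result<InferStreamResponse, InferError>; assumes only tokio and tokio-stream):

// A minimal sketch, not the router's actual types.
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

// Placeholder for Result<InferStreamResponse, InferError>.
type Response = Result<&'static str, String>;

#[tokio::main]
async fn main() {
    let (response_tx, response_rx) = mpsc::unbounded_channel::<Response>();

    // Batching side: push responses until the client goes away.
    tokio::spawn(async move {
        for msg in ["prefill", "intermediate", "end"] {
            // A send error now means the receiver was dropped (client
            // disconnect); there is no timeout case to handle anymore.
            if response_tx.send(Ok(msg)).is_err() {
                break;
            }
        }
    });

    // Handler side: wrap the receiver in a Stream, as generate_stream now does.
    let mut stream = UnboundedReceiverStream::new(response_rx);
    while let Some(item) = stream.next().await {
        println!("{item:?}");
    }
}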
router/src/queue.rs

@@ -5,7 +5,7 @@ use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
-use tokio::sync::oneshot;
+use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};

@@ -15,7 +15,7 @@ pub(crate) struct Entry {
     /// Request
     pub request: ValidGenerateRequest,
     /// Response sender to communicate between the Infer struct and the batching_task
-    pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
+    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
     /// Span that will live as long as entry
     pub span: Span,
     /// Temporary span used as a guard when logging inference, wait times...

@@ -30,13 +30,13 @@ pub(crate) struct Entry {
 #[derive(Debug, Clone)]
 pub(crate) struct Queue {
     /// Channel to communicate with the background queue task
-    queue_sender: flume::Sender<QueueCommand>,
+    queue_sender: mpsc::UnboundedSender<QueueCommand>,
 }
 
 impl Queue {
     pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
         // Create channel
-        let (queue_sender, queue_receiver) = flume::unbounded();
+        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
 
         // Launch background queue task
         tokio::spawn(queue_task(

@@ -91,11 +91,11 @@ async fn queue_task(
     requires_padding: bool,
     block_size: u32,
     window_size: Option<u32>,
-    receiver: flume::Receiver<QueueCommand>,
+    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
     let mut state = State::new(requires_padding, block_size, window_size);
 
-    while let Ok(cmd) = receiver.recv_async().await {
+    while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));

@@ -195,7 +195,7 @@ impl State {
         while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
-            if entry.response_tx.is_disconnected() {
+            if entry.response_tx.is_closed() {
                 metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
                 continue;
             }

@@ -321,9 +321,9 @@ mod tests {
     fn default_entry() -> (
         Entry,
-        flume::Receiver<Result<InferStreamResponse, InferError>>,
+        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
     ) {
-        let (response_tx, receiver_tx) = flume::unbounded();
+        let (response_tx, receiver_tx) = mpsc::unbounded_channel();
 
         let entry = Entry {
             request: ValidGenerateRequest {
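queue.rs keeps its actor layout: Queue derives Clone because the handle holds only an UnboundedSender, while the spawned queue_task owns the receiver, which now must be mut since tokio's recv takes &mut self. A standalone sketch of that pattern, with a hypothetical command enum in place of the router's QueueCommand:

use tokio::sync::{mpsc, oneshot};

// Hypothetical stand-in for the router's QueueCommand.
enum QueueCommand {
    Append(String),
    Len(oneshot::Sender<usize>),
}

// The handle clones cheaply: it owns only the sender half.
#[derive(Clone)]
struct Queue {
    queue_sender: mpsc::UnboundedSender<QueueCommand>,
}

impl Queue {
    fn new() -> Self {
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
        // A single background task owns the receiver and all mutable state.
        tokio::spawn(queue_task(queue_receiver));
        Self { queue_sender }
    }
}

// `recv` takes `&mut self`, hence `mut receiver`; it returns `None`
// once every sender handle has been dropped, ending the task.
async fn queue_task(mut receiver: mpsc::UnboundedReceiver<QueueCommand>) {
    let mut entries: Vec<String> = Vec::new();
    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(entry) => entries.push(entry),
            QueueCommand::Len(tx) => {
                let _ = tx.send(entries.len());
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let queue = Queue::new();
    queue.queue_sender.send(QueueCommand::Append("req-1".into())).unwrap();

    let (tx, rx) = oneshot::channel();
    queue.queue_sender.send(QueueCommand::Len(tx)).unwrap();
    assert_eq!(rx.await.unwrap(), 1);
}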
router/src/validation.rs

@@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
+use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};

@@ -19,7 +20,7 @@ pub struct Validation {
     max_input_length: usize,
     max_total_tokens: usize,
     /// Channel to communicate with the background tokenization task
-    sender: Option<flume::Sender<TokenizerRequest>>,
+    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
 }
 
 impl Validation {

@@ -34,19 +35,25 @@ impl Validation {
     ) -> Self {
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
-            // Create channel
-            let (validation_sender, validation_receiver) = flume::unbounded();
+            // Create round robin channel
+            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
+            let mut senders = Vec::with_capacity(workers);
 
             // Create workers
             for _ in 0..workers {
                 let tokenizer_clone = tokenizer.clone();
-                let receiver_clone = validation_receiver.clone();
+                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
+                senders.push(tokenizer_sender);
 
                 // Spawn worker
                 tokio::task::spawn_blocking(move || {
-                    tokenizer_worker(tokenizer_clone, receiver_clone)
+                    tokenizer_worker(tokenizer_clone, tokenizer_receiver)
                 });
             }
 
+            // Create tokenization round robin task
+            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
+
             Some(validation_sender)
         } else {
             None

@@ -118,12 +125,10 @@ impl Validation {
         // We make sure that truncate + max_new_tokens <= self.max_total_tokens
         let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
             max_new_tokens
-        } else {
-            if let Some(truncate) = truncate {
-                self.max_total_tokens.saturating_sub(truncate) as u32
-            } else {
-                return Err(ValidationError::UnsetMaxNewTokens);
-            }
+        } else if let Some(truncate) = truncate {
+            self.max_total_tokens.saturating_sub(truncate) as u32
+        } else {
+            return Err(ValidationError::UnsetMaxNewTokens);
         };
         let input_length = truncate.unwrap_or(self.max_input_length);

@@ -309,10 +314,25 @@ impl Validation {
     }
 }
 
+/// Round robin tokenization task
+async fn round_robin_task(
+    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
+    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
+) {
+    loop {
+        for sender in &senders {
+            match receiver.recv().await {
+                None => return,
+                Some(request) => sender.send(request).unwrap(),
+            };
+        }
+    }
+}
+
 /// Start tokenization workers
-fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
+fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
     // Loop over requests
-    while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() {
+    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
         parent_span.in_scope(|| {
             response_tx
                 .send(prepare_input(inputs, truncate, &tokenizer))
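validation.rs is the one spot where the migration is not one-for-one. flume receivers implement Clone (the channel is multi-producer multi-consumer), so all tokenizer workers previously pulled from a single shared receiver; tokio's mpsc::UnboundedReceiver has exactly one consumer. The commit therefore gives each blocking worker its own channel and adds round_robin_task to fan requests out, with workers pulling via blocking_recv. A runnable sketch of the same dispatch pattern (hypothetical Request alias in place of TokenizerRequest):

use tokio::sync::mpsc;

type Request = u32; // hypothetical stand-in for TokenizerRequest

// Fan incoming requests out across per-worker channels in round-robin
// order, mirroring the shape of the new round_robin_task above.
async fn round_robin(
    mut receiver: mpsc::UnboundedReceiver<Request>,
    senders: Vec<mpsc::UnboundedSender<Request>>,
) {
    loop {
        for sender in &senders {
            match receiver.recv().await {
                None => return, // all producers dropped
                Some(request) => sender.send(request).unwrap(),
            };
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    let mut worker_txs = Vec::new();
    let mut worker_rxs = Vec::new();
    for _ in 0..2 {
        let (wtx, wrx) = mpsc::unbounded_channel();
        worker_txs.push(wtx);
        worker_rxs.push(wrx);
    }
    tokio::spawn(round_robin(rx, worker_txs));

    // Blocking workers pull with `blocking_recv`, as in the new tokenizer_worker.
    let handles: Vec<_> = worker_rxs
        .into_iter()
        .map(|mut wrx| {
            tokio::task::spawn_blocking(move || {
                while let Some(req) = wrx.blocking_recv() {
                    println!("worker got {req}");
                }
            })
        })
        .collect();

    for i in 0..4 {
        tx.send(i).unwrap();
    }
    drop(tx); // close the pipeline so the dispatcher and workers exit

    for h in handles {
        h.await.unwrap();
    }
}

Note the trade-off: round-robin dispatch gives up flume's natural work sharing (any idle worker could grab the next request from the shared receiver) in exchange for dropping the extra dependency.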