Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
92c1ecd0
Commit
92c1ecd0
authored
Oct 17, 2022
by
Olivier Dehaene
Browse files
feat: Add arguments to CLI
parent
5e5d8766
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
163 additions
and
35 deletions
+163
-35
Dockerfile
Dockerfile
+1
-1
README.md
README.md
+0
-1
router/Cargo.lock
router/Cargo.lock
+100
-2
router/Cargo.toml
router/Cargo.toml
+3
-0
router/src/batcher.rs
router/src/batcher.rs
+17
-11
router/src/lib.rs
router/src/lib.rs
+2
-2
router/src/main.rs
router/src/main.rs
+31
-5
router/src/server.rs
router/src/server.rs
+6
-10
router/src/validation.rs
router/src/validation.rs
+3
-3
No files found.
Dockerfile
View file @
92c1ecd0
...
@@ -9,7 +9,7 @@ WORKDIR /usr/src/router
...
@@ -9,7 +9,7 @@ WORKDIR /usr/src/router
RUN
cargo
install
--path
.
RUN
cargo
install
--path
.
FROM
nvidia/cuda:11.
8.0
-devel-ubuntu
22
.04
FROM
nvidia/cuda:11.
6.1
-devel-ubuntu
18
.04
ENV
LANG=C.UTF-8 \
ENV
LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
LC_ALL=C.UTF-8 \
...
...
README.md
View file @
92c1ecd0
...
@@ -43,7 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
...
@@ -43,7 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
## TODO:
## TODO:
-
[ ] Add batching args to router CLI
-
[ ] Add docstrings + comments everywhere as the codebase is fairly complicated
-
[ ] Add docstrings + comments everywhere as the codebase is fairly complicated
-
[ ] Add tests
-
[ ] Add tests
-
[ ] Add shutdown logic in router and server
-
[ ] Add shutdown logic in router and server
...
...
router/Cargo.lock
View file @
92c1ecd0
...
@@ -253,6 +253,43 @@ dependencies = [
...
@@ -253,6 +253,43 @@ dependencies = [
"vec_map",
"vec_map",
]
]
[[package]]
name = "clap"
version = "4.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f"
dependencies = [
"atty",
"bitflags",
"clap_derive",
"clap_lex",
"once_cell",
"strsim 0.10.0",
"termcolor",
]
[[package]]
name = "clap_derive"
version = "4.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c42f169caba89a7d512b5418b09864543eeb4d497416c917d7137863bd2076ad"
dependencies = [
"heck 0.4.0",
"proc-macro-error",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "clap_lex"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d4198f73e42b4936b35b5bb248d81d2b595ecb170da0bac7655c54eedfa8da8"
dependencies = [
"os_str_bytes",
]
[[package]]
[[package]]
name = "console"
name = "console"
version = "0.15.2"
version = "0.15.2"
...
@@ -701,6 +738,12 @@ dependencies = [
...
@@ -701,6 +738,12 @@ dependencies = [
"unicode-segmentation",
"unicode-segmentation",
]
]
[[package]]
name = "heck"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
[[package]]
[[package]]
name = "hermit-abi"
name = "hermit-abi"
version = "0.1.19"
version = "0.1.19"
...
@@ -1136,6 +1179,12 @@ dependencies = [
...
@@ -1136,6 +1179,12 @@ dependencies = [
"vcpkg",
"vcpkg",
]
]
[[package]]
name = "os_str_bytes"
version = "6.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
[[package]]
[[package]]
name = "parking_lot"
name = "parking_lot"
version = "0.12.1"
version = "0.12.1"
...
@@ -1225,6 +1274,30 @@ version = "0.2.16"
...
@@ -1225,6 +1274,30 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]]
[[package]]
name = "proc-macro2"
name = "proc-macro2"
version = "1.0.46"
version = "1.0.46"
...
@@ -1251,7 +1324,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
...
@@ -1251,7 +1324,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62941722fb675d463659e49c4f3fe1fe792ff24fe5bbaa9c08cd3b98a1c354f5"
checksum = "62941722fb675d463659e49c4f3fe1fe792ff24fe5bbaa9c08cd3b98a1c354f5"
dependencies = [
dependencies = [
"bytes",
"bytes",
"heck",
"heck
0.3.3
",
"itertools 0.10.5",
"itertools 0.10.5",
"lazy_static",
"lazy_static",
"log",
"log",
...
@@ -1601,6 +1674,12 @@ version = "0.9.3"
...
@@ -1601,6 +1674,12 @@ version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
[[package]]
name = "syn"
name = "syn"
version = "1.0.101"
version = "1.0.101"
...
@@ -1643,6 +1722,15 @@ dependencies = [
...
@@ -1643,6 +1722,15 @@ dependencies = [
"winapi",
"winapi",
]
]
[[package]]
name = "termcolor"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755"
dependencies = [
"winapi-util",
]
[[package]]
[[package]]
name = "terminal_size"
name = "terminal_size"
version = "0.1.17"
version = "0.1.17"
...
@@ -1659,6 +1747,7 @@ version = "0.1.0"
...
@@ -1659,6 +1747,7 @@ version = "0.1.0"
dependencies = [
dependencies = [
"axum",
"axum",
"bloom-inference-client",
"bloom-inference-client",
"clap 4.0.15",
"futures",
"futures",
"parking_lot",
"parking_lot",
"serde",
"serde",
...
@@ -1742,7 +1831,7 @@ checksum = "3d7b08ede6742d7a59d58c71da8a6fa21bedc433dca2e855e439274d08df1170"
...
@@ -1742,7 +1831,7 @@ checksum = "3d7b08ede6742d7a59d58c71da8a6fa21bedc433dca2e855e439274d08df1170"
dependencies = [
dependencies = [
"aho-corasick",
"aho-corasick",
"cached-path",
"cached-path",
"clap",
"clap
2.34.0
",
"derive_builder",
"derive_builder",
"dirs",
"dirs",
"esaxx-rs",
"esaxx-rs",
...
@@ -2251,6 +2340,15 @@ version = "0.4.0"
...
@@ -2251,6 +2340,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
dependencies = [
"winapi",
]
[[package]]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
version = "0.4.0"
...
...
router/Cargo.toml
View file @
92c1ecd0
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
name
=
"text-generation-router"
name
=
"text-generation-router"
version
=
"0.1.0"
version
=
"0.1.0"
edition
=
"2021"
edition
=
"2021"
authors
=
[
"Olivier Dehaene"
]
description
=
"Text Generation Webserver"
[lib]
[lib]
path
=
"src/lib.rs"
path
=
"src/lib.rs"
...
@@ -13,6 +15,7 @@ path = "src/main.rs"
...
@@ -13,6 +15,7 @@ path = "src/main.rs"
[dependencies]
[dependencies]
axum
=
{
version
=
"0.5.16"
,
features
=
[
"json"
,
"serde_json"
]
}
axum
=
{
version
=
"0.5.16"
,
features
=
[
"json"
,
"serde_json"
]
}
bloom-inference-client
=
{
path
=
"client"
}
bloom-inference-client
=
{
path
=
"client"
}
clap
=
{
version
=
"4.0.15"
,
features
=
[
"derive"
,
"env"
]
}
futures
=
"0.3.24"
futures
=
"0.3.24"
parking_lot
=
"0.12.1"
parking_lot
=
"0.12.1"
serde
=
"1.0.145"
serde
=
"1.0.145"
...
...
router/src/batcher.rs
View file @
92c1ecd0
...
@@ -27,7 +27,7 @@ impl From<InferError> for (StatusCode, String) {
...
@@ -27,7 +27,7 @@ impl From<InferError> for (StatusCode, String) {
}
}
#[derive(Clone)]
#[derive(Clone)]
pub
(
crate
)
struct
Batcher
{
pub
struct
Batcher
{
db
:
Db
,
db
:
Db
,
shared
:
Arc
<
Shared
>
,
shared
:
Arc
<
Shared
>
,
}
}
...
@@ -37,13 +37,13 @@ struct Shared {
...
@@ -37,13 +37,13 @@ struct Shared {
}
}
impl
Batcher
{
impl
Batcher
{
pub
(
crate
)
fn
new
(
client
:
ShardedClient
)
->
Self
{
pub
(
crate
)
fn
new
(
client
:
ShardedClient
,
max_batch_size
:
usize
)
->
Self
{
let
db
=
Db
::
new
();
let
db
=
Db
::
new
();
let
shared
=
Arc
::
new
(
Shared
{
let
shared
=
Arc
::
new
(
Shared
{
batching_task
:
Notify
::
new
(),
batching_task
:
Notify
::
new
(),
});
});
tokio
::
spawn
(
batching_task
(
client
,
db
.clone
(),
shared
.clone
()));
tokio
::
spawn
(
batching_task
(
max_batch_size
,
client
,
db
.clone
(),
shared
.clone
()));
Self
{
db
,
shared
}
Self
{
db
,
shared
}
}
}
...
@@ -70,40 +70,46 @@ impl Batcher {
...
@@ -70,40 +70,46 @@ impl Batcher {
}
}
}
}
async
fn
batching_task
(
client
:
ShardedClient
,
db
:
Db
,
shared
:
Arc
<
Shared
>
)
{
async
fn
batching_task
(
max_batch_size
:
usize
,
client
:
ShardedClient
,
db
:
Db
,
shared
:
Arc
<
Shared
>
)
{
let
limit_min_batch_size
=
(
max_batch_size
/
2
)
as
u32
;
loop
{
loop
{
shared
.batching_task
.notified
()
.await
;
shared
.batching_task
.notified
()
.await
;
if
let
Some
(
batch
)
=
db
.next_batch
(
32
)
{
if
let
Some
(
batch
)
=
db
.next_batch
(
max_batch_size
)
{
let
request_ids
=
batch
.requests
.iter
()
.map
(|
req
|
req
.id
)
.collect
();
let
request_ids
=
batch
.requests
.iter
()
.map
(|
req
|
req
.id
)
.collect
();
let
mut
cached_batch
=
match
batch
.size
{
let
mut
cached_batch
=
match
batch
.size
{
size
if
size
>
16
=>
{
size
if
size
>
limit_min_batch_size
=>
{
wrap_future
(
client
.generate_until_finished
(
batch
),
request_ids
,
&
db
)
.await
wrap_future
(
client
.generate_until_finished
(
batch
),
request_ids
,
&
db
)
.await
}
}
_
=>
wrap_future
(
client
.generate
(
batch
),
request_ids
,
&
db
)
.await
,
_
=>
wrap_future
(
client
.generate
(
batch
),
request_ids
,
&
db
)
.await
,
};
};
while
let
Some
(
batch
)
=
cached_batch
{
while
let
Some
(
batch
)
=
cached_batch
{
let
batch_size
=
batch
.size
;
let
mut
current_
batch_size
=
batch
.size
;
let
mut
request_ids
:
Vec
<
u64
>
=
batch
.requests
.iter
()
.map
(|
req
|
req
.id
)
.collect
();
let
mut
request_ids
:
Vec
<
u64
>
=
batch
.requests
.iter
()
.map
(|
req
|
req
.id
)
.collect
();
let
mut
batches
=
vec!
[
batch
];
let
mut
batches
=
vec!
[
batch
];
if
batch_size
<=
16
{
if
current_
batch_size
<=
limit_min_batch_size
{
if
let
Some
(
new_batch
)
=
db
.next_batch_minimum_size
(
16
,
48
)
{
if
let
Some
(
new_batch
)
=
db
.next_batch_minimum_size
(
limit_min_batch_size
as
usize
,
max_batch_size
)
{
let
new_batch_request_ids
=
let
new_batch_request_ids
=
new_batch
.requests
.iter
()
.map
(|
req
|
req
.id
)
.collect
();
new_batch
.requests
.iter
()
.map
(|
req
|
req
.id
)
.collect
();
let
new_cached_batch
=
let
new_cached_batch
=
wrap_future
(
client
.generate
(
new_batch
),
new_batch_request_ids
,
&
db
)
wrap_future
(
client
.generate
(
new_batch
),
new_batch_request_ids
,
&
db
)
.await
;
.await
;
if
let
Some
(
new_cached_batch
)
=
new_cached_batch
{
if
let
Some
(
new_cached_batch
)
=
new_cached_batch
{
current_batch_size
+=
new_cached_batch
.size
;
request_ids
.extend
(
new_cached_batch
.requests
.iter
()
.map
(|
req
|
req
.id
));
request_ids
.extend
(
new_cached_batch
.requests
.iter
()
.map
(|
req
|
req
.id
));
batches
.push
(
new_cached_batch
);
batches
.push
(
new_cached_batch
);
}
}
}
}
}
}
cached_batch
=
match
batch_size
{
cached_batch
=
match
current_
batch_size
{
size
if
size
>
16
=>
{
size
if
size
>
limit_min_batch_size
=>
{
wrap_future
(
wrap_future
(
client
.generate_until_finished_with_cache
(
batches
),
client
.generate_until_finished_with_cache
(
batches
),
request_ids
,
request_ids
,
...
...
router/src/lib.rs
View file @
92c1ecd0
mod
batcher
;
mod
batcher
;
mod
db
;
mod
db
;
pub
mod
server
;
mod
validation
;
mod
validation
;
pub
mod
server
;
use
batcher
::
Batcher
;
use
db
::{
Db
,
Entry
};
use
db
::{
Db
,
Entry
};
use
batcher
::
Batcher
;
use
validation
::
Validation
;
use
validation
::
Validation
;
router/src/main.rs
View file @
92c1ecd0
use
bloom_inference_client
::
ShardedClient
;
use
bloom_inference_client
::
ShardedClient
;
use
std
::
net
::
SocketAddr
;
use
std
::
net
::
{
IpAddr
,
Ipv4Addr
,
SocketAddr
}
;
use
text_generation_router
::
server
;
use
text_generation_router
::
server
;
use
tokenizers
::
Tokenizer
;
use
tokenizers
::
Tokenizer
;
use
clap
::
Parser
;
/// App Configuration
#[derive(Parser,
Debug)]
#[clap(author,
version,
about,
long_about
=
None)]
struct
Args
{
#[clap(default_value
=
"32"
,
long,
short,
env)]
max_batch_size
:
usize
,
#[clap(default_value
=
"3000"
,
long,
short,
env)]
port
:
u16
,
#[clap(default_value
=
"/tmp/bloom-inference-0"
,
long,
env)]
shard_uds_path
:
String
,
#[clap(default_value
=
"bigscience/bloom"
,
long,
env)]
tokenizer_name
:
String
,
}
fn
main
()
->
Result
<
(),
std
::
io
::
Error
>
{
fn
main
()
->
Result
<
(),
std
::
io
::
Error
>
{
let
tokenizer
=
Tokenizer
::
from_pretrained
(
"bigscience/bloom"
,
None
)
.unwrap
();
// Get args
let
args
=
Args
::
parse
();
// Pattern match configuration
let
Args
{
max_batch_size
,
port
,
shard_uds_path
,
tokenizer_name
,
}
=
args
;
let
tokenizer
=
Tokenizer
::
from_pretrained
(
tokenizer_name
,
None
)
.unwrap
();
tokio
::
runtime
::
Builder
::
new_multi_thread
()
tokio
::
runtime
::
Builder
::
new_multi_thread
()
.enable_all
()
.enable_all
()
...
@@ -13,7 +39,7 @@ fn main() -> Result<(), std::io::Error> {
...
@@ -13,7 +39,7 @@ fn main() -> Result<(), std::io::Error> {
.block_on
(
async
{
.block_on
(
async
{
tracing_subscriber
::
fmt
::
init
();
tracing_subscriber
::
fmt
::
init
();
let
sharded_client
=
ShardedClient
::
connect_uds
(
"/tmp/bloom-inference-0"
.to_string
()
)
let
sharded_client
=
ShardedClient
::
connect_uds
(
shard_uds_path
)
.await
.await
.expect
(
"Could not connect to server"
);
.expect
(
"Could not connect to server"
);
sharded_client
sharded_client
...
@@ -22,9 +48,9 @@ fn main() -> Result<(), std::io::Error> {
...
@@ -22,9 +48,9 @@ fn main() -> Result<(), std::io::Error> {
.expect
(
"Unable to clear cache"
);
.expect
(
"Unable to clear cache"
);
tracing
::
info!
(
"Connected"
);
tracing
::
info!
(
"Connected"
);
let
addr
=
SocketAddr
::
from
(([
0
,
0
,
0
,
0
],
3000
)
);
let
addr
=
SocketAddr
::
new
(
IpAddr
::
V4
(
Ipv4Addr
::
new
(
0
,
0
,
0
,
0
)),
port
);
server
::
run
(
sharded_client
,
tokenizer
,
addr
)
.await
;
server
::
run
(
max_batch_size
,
sharded_client
,
tokenizer
,
addr
)
.await
;
Ok
(())
Ok
(())
})
})
}
}
router/src/server.rs
View file @
92c1ecd0
...
@@ -64,7 +64,7 @@ pub(crate) struct GenerateRequest {
...
@@ -64,7 +64,7 @@ pub(crate) struct GenerateRequest {
#[instrument(skip(state),
fields(time,
time_per_token))]
#[instrument(skip(state),
fields(time,
time_per_token))]
async
fn
liveness
(
state
:
Extension
<
ServerState
>
)
->
Result
<
(),
(
StatusCode
,
String
)
>
{
async
fn
liveness
(
state
:
Extension
<
ServerState
>
)
->
Result
<
(),
(
StatusCode
,
String
)
>
{
state
state
.
inf
er
.
batch
er
.infer
(
.infer
(
1
,
1
,
GenerateRequest
{
GenerateRequest
{
...
@@ -97,7 +97,7 @@ async fn generate(
...
@@ -97,7 +97,7 @@ async fn generate(
})
})
.await
?
;
.await
?
;
let
generated_text
=
state
.
inf
er
.infer
(
input_length
,
validated_request
)
.await
?
;
let
generated_text
=
state
.
batch
er
.infer
(
input_length
,
validated_request
)
.await
?
;
tracing
::
Span
::
current
()
.record
(
"time"
,
format!
(
"{:?}"
,
start
.elapsed
()));
tracing
::
Span
::
current
()
.record
(
"time"
,
format!
(
"{:?}"
,
start
.elapsed
()));
tracing
::
Span
::
current
()
.record
(
tracing
::
Span
::
current
()
.record
(
...
@@ -114,18 +114,14 @@ async fn generate(
...
@@ -114,18 +114,14 @@ async fn generate(
#[derive(Clone)]
#[derive(Clone)]
struct
ServerState
{
struct
ServerState
{
validation
:
Validation
,
validation
:
Validation
,
inf
er
:
Batcher
,
batch
er
:
Batcher
,
}
}
pub
async
fn
run
(
client
:
ShardedClient
,
tokenizer
:
Tokenizer
,
addr
:
SocketAddr
)
{
pub
async
fn
run
(
max_batch_size
:
usize
,
client
:
ShardedClient
,
tokenizer
:
Tokenizer
,
addr
:
SocketAddr
)
{
client
.clear_cache
()
.await
.expect
(
"Unable to clear cache"
);
let
batcher
=
Batcher
::
new
(
client
,
max_batch_size
);
tracing
::
info!
(
"Connected"
);
let
infer
=
Batcher
::
new
(
client
);
let
validation
=
Validation
::
new
(
tokenizer
);
let
validation
=
Validation
::
new
(
tokenizer
);
let
shared_state
=
ServerState
{
validation
,
inf
er
};
let
shared_state
=
ServerState
{
validation
,
batch
er
};
let
app
=
Router
::
new
()
let
app
=
Router
::
new
()
.route
(
"/generate"
,
post
(
generate
))
.route
(
"/generate"
,
post
(
generate
))
...
...
router/src/validation.rs
View file @
92c1ecd0
...
@@ -14,7 +14,7 @@ pub enum ValidationError {
...
@@ -14,7 +14,7 @@ pub enum ValidationError {
TopK
,
TopK
,
#[error(
"Max New Tokens must be < 512"
)]
#[error(
"Max New Tokens must be < 512"
)]
MaxNewTokens
,
MaxNewTokens
,
#[error(
"Inputs must have less than
512
tokens. Given: {0}"
)]
#[error(
"Inputs must have less than
1000
tokens. Given: {0}"
)]
InputLength
(
usize
),
InputLength
(
usize
),
}
}
...
@@ -30,7 +30,7 @@ type ValidationRequest = (
...
@@ -30,7 +30,7 @@ type ValidationRequest = (
);
);
#[derive(Debug,
Clone)]
#[derive(Debug,
Clone)]
pub
(
crate
)
struct
Validation
{
pub
struct
Validation
{
sender
:
mpsc
::
Sender
<
ValidationRequest
>
,
sender
:
mpsc
::
Sender
<
ValidationRequest
>
,
}
}
...
@@ -81,7 +81,7 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
...
@@ -81,7 +81,7 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
let
inputs
=
tokenizer
.encode
(
request
.inputs
.clone
(),
false
)
.unwrap
();
let
inputs
=
tokenizer
.encode
(
request
.inputs
.clone
(),
false
)
.unwrap
();
let
input_length
=
inputs
.len
();
let
input_length
=
inputs
.len
();
if
input_length
>
512
{
if
input_length
>
1000
{
response_tx
response_tx
.send
(
Err
(
ValidationError
::
InputLength
(
input_length
)))
.send
(
Err
(
ValidationError
::
InputLength
(
input_length
)))
.unwrap_or
(());
.unwrap_or
(());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment