Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
39df4d99
Commit
39df4d99
authored
Oct 11, 2022
by
Olivier Dehaene
Browse files
Use axum
parent
e86ecbac
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
121 additions
and
148 deletions
+121
-148
router/Cargo.lock
router/Cargo.lock
+85
-117
router/Cargo.toml
router/Cargo.toml
+1
-3
router/client/Cargo.toml
router/client/Cargo.toml
+3
-3
router/src/main.rs
router/src/main.rs
+4
-4
router/src/server.rs
router/src/server.rs
+28
-21
No files found.
router/Cargo.lock
View file @
39df4d99
...
@@ -81,6 +81,53 @@ version = "1.1.0"
...
@@ -81,6 +81,53 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.5.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043"
dependencies = [
"async-trait",
"axum-core",
"bitflags",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-http",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"mime",
"tower-layer",
"tower-service",
]
[[package]]
[[package]]
name = "base64"
name = "base64"
version = "0.13.0"
version = "0.13.0"
...
@@ -106,10 +153,10 @@ dependencies = [
...
@@ -106,10 +153,10 @@ dependencies = [
name = "bloom-inference"
name = "bloom-inference"
version = "0.1.0"
version = "0.1.0"
dependencies = [
dependencies = [
"axum",
"bloom-inference-client",
"bloom-inference-client",
"futures",
"futures",
"parking_lot",
"parking_lot",
"poem",
"serde",
"serde",
"serde_json",
"serde_json",
"tokenizers",
"tokenizers",
...
@@ -661,31 +708,6 @@ version = "0.12.3"
...
@@ -661,31 +708,6 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "headers"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584"
dependencies = [
"base64",
"bitflags",
"bytes",
"headers-core",
"http",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
dependencies = [
"http",
]
[[package]]
[[package]]
name = "heck"
name = "heck"
version = "0.3.3"
version = "0.3.3"
...
@@ -726,6 +748,12 @@ dependencies = [
...
@@ -726,6 +748,12 @@ dependencies = [
"pin-project-lite",
"pin-project-lite",
]
]
[[package]]
name = "http-range-header"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]]
[[package]]
name = "httparse"
name = "httparse"
version = "1.8.0"
version = "1.8.0"
...
@@ -941,6 +969,12 @@ version = "0.1.2"
...
@@ -941,6 +969,12 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea"
checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea"
[[package]]
name = "matchit"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
[[package]]
[[package]]
name = "memchr"
name = "memchr"
version = "2.5.0"
version = "2.5.0"
...
@@ -1201,65 +1235,12 @@ version = "0.3.25"
...
@@ -1201,65 +1235,12 @@ version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
[[package]]
name = "poem"
version = "1.3.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2992ba72908e36200671c0f3a692992ced894b3b2bbe2b2dc6dfbffea6e2c85a"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"headers",
"http",
"hyper",
"mime",
"parking_lot",
"percent-encoding",
"pin-project-lite",
"poem-derive",
"regex",
"rfc7239",
"serde",
"serde_json",
"serde_urlencoded",
"smallvec",
"thiserror",
"tokio",
"tokio-stream",
"tokio-util 0.7.4",
"tracing",
]
[[package]]
name = "poem-derive"
version = "1.3.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f535d4331a22610b98ca48f98bae9bda0c654da89b9ae10a1830fa9edfd8f36"
dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn",
]
[[package]]
[[package]]
name = "ppv-lite86"
name = "ppv-lite86"
version = "0.2.16"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
[[package]]
name = "proc-macro-crate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9"
dependencies = [
"once_cell",
"thiserror",
"toml",
]
[[package]]
[[package]]
name = "proc-macro2"
name = "proc-macro2"
version = "1.0.46"
version = "1.0.46"
...
@@ -1479,15 +1460,6 @@ dependencies = [
...
@@ -1479,15 +1460,6 @@ dependencies = [
"winreg",
"winreg",
]
]
[[package]]
name = "rfc7239"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "087317b3cf7eb481f13bd9025d729324b7cd068d6f470e2d76d049e191f5ba47"
dependencies = [
"uncased",
]
[[package]]
[[package]]
name = "ryu"
name = "ryu"
version = "1.0.11"
version = "1.0.11"
...
@@ -1576,17 +1548,6 @@ dependencies = [
...
@@ -1576,17 +1548,6 @@ dependencies = [
"serde",
"serde",
]
]
[[package]]
name = "sha1"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
[[package]]
name = "sha2"
name = "sha2"
version = "0.10.6"
version = "0.10.6"
...
@@ -1667,6 +1628,12 @@ dependencies = [
...
@@ -1667,6 +1628,12 @@ dependencies = [
"unicode-ident",
"unicode-ident",
]
]
[[package]]
name = "sync_wrapper"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
[[package]]
[[package]]
name = "tar"
name = "tar"
version = "0.4.38"
version = "0.4.38"
...
@@ -1890,15 +1857,6 @@ dependencies = [
...
@@ -1890,15 +1857,6 @@ dependencies = [
"tracing",
"tracing",
]
]
[[package]]
name = "toml"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7"
dependencies = [
"serde",
]
[[package]]
[[package]]
name = "tonic"
name = "tonic"
version = "0.6.2"
version = "0.6.2"
...
@@ -1962,6 +1920,25 @@ dependencies = [
...
@@ -1962,6 +1920,25 @@ dependencies = [
"tracing",
"tracing",
]
]
[[package]]
name = "tower-http"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c530c8675c1dbf98facee631536fa116b5fb6382d7dd6dc1b118d970eafe3ba"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
[[package]]
name = "tower-layer"
name = "tower-layer"
version = "0.3.1"
version = "0.3.1"
...
@@ -2065,15 +2042,6 @@ version = "1.15.0"
...
@@ -2065,15 +2042,6 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "uncased"
version = "0.9.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09b01702b0fd0b3fadcf98e098780badda8742d4f4a7676615cad90e8ac73622"
dependencies = [
"version_check",
]
[[package]]
[[package]]
name = "unicode-bidi"
name = "unicode-bidi"
version = "0.3.8"
version = "0.3.8"
...
...
router/Cargo.toml
View file @
39df4d99
...
@@ -3,13 +3,11 @@ name = "bloom-inference"
...
@@ -3,13 +3,11 @@ name = "bloom-inference"
version
=
"0.1.0"
version
=
"0.1.0"
edition
=
"2021"
edition
=
"2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
[dependencies]
axum
=
{
version
=
"0.5.16"
,
features
=
[
"json"
,
"serde_json"
]
}
bloom-inference-client
=
{
path
=
"client"
}
bloom-inference-client
=
{
path
=
"client"
}
futures
=
"0.3.24"
futures
=
"0.3.24"
parking_lot
=
"0.12.1"
parking_lot
=
"0.12.1"
poem
=
"1.3.45"
serde
=
"1.0.145"
serde
=
"1.0.145"
serde_json
=
"1.0.85"
serde_json
=
"1.0.85"
tokenizers
=
"0.13.0"
tokenizers
=
"0.13.0"
...
...
router/client/Cargo.toml
View file @
39df4d99
...
@@ -4,12 +4,12 @@ version = "0.1.0"
...
@@ -4,12 +4,12 @@ version = "0.1.0"
edition
=
"2021"
edition
=
"2021"
[dependencies]
[dependencies]
futures
=
"0.3
.24
"
futures
=
"
^
0.3"
#grpc-error-details = { path = "../../grpc-error-details" }
#grpc-error-details = { path = "../../grpc-error-details" }
#grpc-metadata = { path = "../../grpc-metadata" }
#grpc-metadata = { path = "../../grpc-metadata" }
prost
=
"^0.9"
prost
=
"^0.9"
thiserror
=
"1.0
.37
"
thiserror
=
"
^
1.0"
tokio
=
{
version
=
"1.21
.2
"
,
features
=
["sync"]
}
tokio
=
{
version
=
"
^
1.21"
,
features
=
["sync"]
}
tonic
=
"^0.6"
tonic
=
"^0.6"
tower
=
"^0.4"
tower
=
"^0.4"
tracing
=
"^0.1"
tracing
=
"^0.1"
...
...
router/src/main.rs
View file @
39df4d99
use
std
::
net
::
SocketAddr
;
use
bloom_inference_client
::
ShardedClient
;
use
bloom_inference_client
::
ShardedClient
;
use
poem
::
listener
::
TcpListener
;
use
std
::
time
::
Duration
;
use
std
::
time
::
Duration
;
use
tokenizers
::
Tokenizer
;
use
tokenizers
::
Tokenizer
;
...
@@ -37,9 +37,9 @@ fn main() -> Result<(), std::io::Error> {
...
@@ -37,9 +37,9 @@ fn main() -> Result<(), std::io::Error> {
.expect
(
"Unable to clear cache"
);
.expect
(
"Unable to clear cache"
);
tracing
::
info!
(
"Connected"
);
tracing
::
info!
(
"Connected"
);
let
addr
=
"127.0.0.1:3000"
.to_string
();
let
addr
=
SocketAddr
::
from
(([
127
,
0
,
0
,
1
],
3000
));
let
listener
=
TcpListener
::
bind
(
addr
);
server
::
run
(
sharded_client
,
tokenizer
,
listener
)
.await
server
::
run
(
sharded_client
,
tokenizer
,
addr
)
.await
;
Ok
(())
})
})
}
}
router/src/server.rs
View file @
39df4d99
use
std
::
net
::
SocketAddr
;
use
axum
::{
Router
,
Json
};
use
axum
::
http
::
StatusCode
;
use
axum
::
extract
::
Extension
;
use
axum
::
routing
::
post
;
use
crate
::{
Batcher
,
ShardedClient
,
Validation
};
use
crate
::{
Batcher
,
ShardedClient
,
Validation
};
use
poem
::
http
::
StatusCode
;
use
poem
::
listener
::
TcpListener
;
use
poem
::
middleware
::
AddData
;
use
poem
::
web
::{
Data
,
Json
};
use
poem
::{
handler
,
post
,
EndpointExt
,
Route
,
Server
};
use
serde
::
Deserialize
;
use
serde
::
Deserialize
;
use
tokenizers
::
Tokenizer
;
use
tokenizers
::
Tokenizer
;
use
tokio
::
time
::
Instant
;
use
tokio
::
time
::
Instant
;
...
@@ -60,26 +60,24 @@ pub(crate) struct GenerateRequest {
...
@@ -60,26 +60,24 @@ pub(crate) struct GenerateRequest {
pub
parameters
:
GenerateParameters
,
pub
parameters
:
GenerateParameters
,
}
}
#[handler]
#[instrument(skip(state),
fields(time,
time_per_token))]
#[instrument(skip(validation,
infer),
fields(time,
time_per_token))]
async
fn
generate
(
async
fn
generate
(
validation
:
Data
<&
Validation
>
,
state
:
Extension
<
ServerState
>
,
infer
:
Data
<&
Batcher
>
,
req
:
Json
<
GenerateRequest
>
,
req
:
Json
<
GenerateRequest
>
,
)
->
poem
::
Result
<
Json
<
serde_json
::
Value
>>
{
)
->
Result
<
Json
<
serde_json
::
Value
>
,
StatusCode
>
{
let
start
=
Instant
::
now
();
let
start
=
Instant
::
now
();
let
(
input_length
,
validated_request
)
=
match
validation
let
(
input_length
,
validated_request
)
=
match
state
.
validation
.validate
(
GenerateRequest
{
.validate
(
GenerateRequest
{
inputs
:
req
.inputs
.clone
(),
inputs
:
req
.inputs
.clone
(),
parameters
:
req
.parameters
.clone
(),
parameters
:
req
.parameters
.clone
(),
})
})
.await
{
.await
{
Ok
(
result
)
=>
result
,
Ok
(
result
)
=>
result
,
Err
(
_
)
=>
return
Err
(
poem
::
Error
::
from_status
(
StatusCode
::
INTERNAL_SERVER_ERROR
)
)
Err
(
_
)
=>
return
Err
(
StatusCode
::
INTERNAL_SERVER_ERROR
)
};
};
let
output
=
infer
.infer
(
input_length
,
validated_request
)
.await
;
let
output
=
state
.
infer
.infer
(
input_length
,
validated_request
)
.await
;
match
output
{
match
output
{
Ok
(
generated_text
)
=>
{
Ok
(
generated_text
)
=>
{
...
@@ -94,15 +92,21 @@ async fn generate(
...
@@ -94,15 +92,21 @@ async fn generate(
"generated_text"
:
generated_text
,
"generated_text"
:
generated_text
,
})))
})))
}
}
Err
(
_
)
=>
Err
(
poem
::
Error
::
from_status
(
StatusCode
::
INTERNAL_SERVER_ERROR
)
)
,
Err
(
_
)
=>
Err
(
StatusCode
::
INTERNAL_SERVER_ERROR
),
}
}
}
}
#[derive(Clone)]
struct
ServerState
{
validation
:
Validation
,
infer
:
Batcher
,
}
pub
async
fn
run
(
pub
async
fn
run
(
client
:
ShardedClient
,
client
:
ShardedClient
,
tokenizer
:
Tokenizer
,
tokenizer
:
Tokenizer
,
listener
:
TcpListener
<
String
>
,
addr
:
SocketAddr
,
)
->
Result
<
(),
std
::
io
::
Error
>
{
)
{
client
.clear_cache
()
.await
.expect
(
"Unable to clear cache"
);
client
.clear_cache
()
.await
.expect
(
"Unable to clear cache"
);
tracing
::
info!
(
"Connected"
);
tracing
::
info!
(
"Connected"
);
...
@@ -110,10 +114,13 @@ pub async fn run(
...
@@ -110,10 +114,13 @@ pub async fn run(
let
validation
=
Validation
::
new
(
tokenizer
);
let
validation
=
Validation
::
new
(
tokenizer
);
let
app
=
Route
::
new
()
let
shared_state
=
ServerState
{
.at
(
"/generate"
,
post
(
generate
))
validation
,
.with
(
AddData
::
new
(
validation
))
infer
,
.with
(
AddData
::
new
(
infer
));
};
let
app
=
Router
::
new
()
.route
(
"/generate"
,
post
(
generate
))
.layer
(
Extension
(
shared_state
));
Server
::
new
(
listener
)
.run
(
app
)
.await
axum
::
Server
::
bind
(
&
addr
)
.serve
(
app
.into_make_service
())
.await
.unwrap
();
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment