OpenDAS / text-generation-inference · Commits
Commit 5e5d8766, authored Oct 17, 2022 by Olivier Dehaene
Parent: 00e6ce44

feat: Improve error handling
Showing 19 changed files with 270 additions and 176 deletions.
Changed files:
  Dockerfile                            +2  -1
  README.md                             +0  -1
  aml/deployment.yaml                   +7  -7
  router/Cargo.lock                     +17 -16
  router/Cargo.toml                     +9  -1
  router/client/src/client.rs           +13 -21
  router/client/src/lib.rs              +13 -9
  router/client/src/sharded_client.rs   +10 -13
  router/src/batcher.rs                 +34 -10
  router/src/db.rs                      +32 -27
  router/src/lib.rs                     +8  -0
  router/src/main.rs                    +4  -19
  router/src/server.rs                  +23 -36
  router/src/validation.rs              +35 -7
  run.sh                                +13 -4
  server/bloom_inference/cli.py         +42 -0
  server/bloom_inference/server.py      +2  -4
  server/bloom_inference/utils.py       +3  -0
  server/pyproject.toml                 +3  -0
Dockerfile

@@ -18,6 +18,7 @@ ENV LANG=C.UTF-8 \
     MODEL_NAME=bigscience/bloom \
     NUM_GPUS=8 \
     CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+    NCCL_ASYNC_ERROR_HANDLING=1 \
     CUDA_HOME=/usr/local/cuda \
     LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
     CONDA_DEFAULT_ENV=text-generation \
@@ -51,7 +52,7 @@ RUN cd server && \
     /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir

 # Install router
-COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference
+COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router

 COPY run.sh .
 RUN chmod +x run.sh
README.md

@@ -48,5 +48,4 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
 - [ ] Add tests
 - [ ] Add shutdown logic in router and server
 - [ ] Improve multi-processing logic in server
-- [ ] Improve error handling everywhere
 - [ ] Improve past key layer indexing?
\ No newline at end of file
aml/deployment.yaml

@@ -8,7 +8,7 @@ environment_variables:
   MODEL_NAME: bigscience/bloom
   NUM_GPUS: 8
 environment:
-  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
+  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3
 inference_config:
   liveness_route:
     port: 3000
@@ -24,15 +24,15 @@ request_settings:
   request_timeout_ms: 90000
   max_concurrent_requests_per_instance: 256
 liveness_probe:
-  initial_delay: 300
+  initial_delay: 600
   timeout: 20
-  period: 60
+  period: 120
   success_threshold: 1
-  failure_threshold: 60
+  failure_threshold: 3
 readiness_probe:
-  initial_delay: 300
+  initial_delay: 600
   timeout: 20
-  period: 60
+  period: 120
   success_threshold: 1
-  failure_threshold: 60
+  failure_threshold: 3
 instance_count: 1
router/Cargo.lock

@@ -149,22 +149,6 @@ dependencies = [
  "generic-array",
 ]

-[[package]]
-name = "bloom-inference"
-version = "0.1.0"
-dependencies = [
- "axum",
- "bloom-inference-client",
- "futures",
- "parking_lot",
- "serde",
- "serde_json",
- "tokenizers",
- "tokio",
- "tracing",
- "tracing-subscriber",
-]
-
 [[package]]
 name = "bloom-inference-client"
 version = "0.1.0"
@@ -1669,6 +1653,23 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "text-generation-router"
+version = "0.1.0"
+dependencies = [
+ "axum",
+ "bloom-inference-client",
+ "futures",
+ "parking_lot",
+ "serde",
+ "serde_json",
+ "thiserror",
+ "tokenizers",
+ "tokio",
+ "tracing",
+ "tracing-subscriber",
+]
+
 [[package]]
 name = "textwrap"
 version = "0.11.0"
router/Cargo.toml

 [package]
-name = "bloom-inference"
+name = "text-generation-router"
 version = "0.1.0"
 edition = "2021"

+[lib]
+path = "src/lib.rs"
+
+[[bin]]
+name = "text-generation-router"
+path = "src/main.rs"
+
 [dependencies]
 axum = { version = "0.5.16", features = ["json", "serde_json"] }
 bloom-inference-client = { path = "client" }
@@ -10,6 +17,7 @@ futures = "0.3.24"
 parking_lot = "0.12.1"
 serde = "1.0.145"
 serde_json = "1.0.85"
+thiserror = "1.0.37"
 tokenizers = "0.13.0"
 tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
 tracing = "0.1.36"
router/client/src/client.rs

 use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
 use crate::pb::generate::v1::*;
 use crate::Result;
-use std::time::Duration;
 use tonic::transport::{Channel, Uri};
-use tower::timeout::Timeout;
 use tracing::*;

 /// BLOOM Inference gRPC client
 #[derive(Clone)]
 pub struct Client {
-    stub: TextGenerationServiceClient<Timeout<Channel>>,
+    stub: TextGenerationServiceClient<Channel>,
 }

 impl Client {
-    /// Returns a client connected to the given url. Requests exceeding timeout will fail.
-    pub async fn connect(uri: Uri, timeout: Duration) -> Self {
-        let channel = Channel::builder(uri)
-            .connect()
-            .await
-            .expect("Transport error");
-        let timeout_channel = Timeout::new(channel, timeout);
-
-        Self {
-            stub: TextGenerationServiceClient::new(timeout_channel),
-        }
+    /// Returns a client connected to the given url
+    pub async fn connect(uri: Uri) -> Result<Self> {
+        let channel = Channel::builder(uri).connect().await?;
+
+        Ok(Self {
+            stub: TextGenerationServiceClient::new(channel),
+        })
     }

-    /// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
-    pub async fn connect_uds(path: String, timeout: Duration) -> Self {
+    /// Returns a client connected to the given unix socket
+    pub async fn connect_uds(path: String) -> Result<Self> {
         let channel = Channel::from_shared("http://[::]:50051".to_string())
             .unwrap()
             .connect_with_connector(tower::service_fn(move |_: Uri| {
                 tokio::net::UnixStream::connect(path.clone())
             }))
-            .await
-            .expect("Transport error");
-        let timeout_channel = Timeout::new(channel, timeout);
-
-        Self {
-            stub: TextGenerationServiceClient::new(timeout_channel),
-        }
+            .await?;
+
+        Ok(Self {
+            stub: TextGenerationServiceClient::new(channel),
+        })
     }

     #[instrument(skip(self))]
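For context on how the new fallible constructors are meant to be consumed, here is a minimal sketch (not part of the commit) of a caller propagating the failure with `?`. It assumes the crate's `Result` alias resolves to `Result<T, ClientError>` and that `Client` and `ClientError` are exported from the crate root; the URI and the helper function name are purely illustrative.

use bloom_inference_client::{Client, ClientError};
use tonic::transport::Uri;

// Hypothetical helper for the example: builds a client and surfaces any
// transport failure as a ClientError instead of panicking with expect().
async fn build_client() -> Result<Client, ClientError> {
    // Illustrative address only.
    let uri: Uri = "http://localhost:50051".parse().expect("static URI is valid");
    // Inside connect(), the `?` on the tonic transport error selects the
    // ClientError::Connection variant via the new From impl in lib.rs.
    let client = Client::connect(uri).await?;
    Ok(client)
}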
router/client/src/lib.rs

@@ -8,22 +8,26 @@ pub use client::Client;
 pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
-pub use tonic::transport::Uri;
+pub use tonic::transport;
 use tonic::Status;

 #[derive(Error, Debug, Clone)]
-#[error("Text generation client error: {msg:?}")]
-pub struct ClientError {
-    msg: String,
-    // source: Status,
+pub enum ClientError {
+    #[error("Could not connect to Text Generation server: {0:?}")]
+    Connection(String),
+    #[error("Server error: {0:?}")]
+    Generation(String),
 }

 impl From<Status> for ClientError {
     fn from(err: Status) -> Self {
-        Self {
-            msg: err.to_string(),
-            // source: err,
-        }
+        Self::Generation(err.to_string())
+    }
+}
+
+impl From<transport::Error> for ClientError {
+    fn from(err: transport::Error) -> Self {
+        Self::Connection(err.to_string())
     }
 }
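The change above swaps an opaque single-field error struct for a `thiserror` enum with one `From` impl per failure source (gRPC status vs. transport error), so `?` picks the right variant automatically. A standalone sketch of what the derive produces; the enum mirrors the diff, the `main` and its message are only for illustration:

use thiserror::Error;

#[derive(Error, Debug, Clone)]
pub enum ClientError {
    #[error("Could not connect to Text Generation server: {0:?}")]
    Connection(String),
    #[error("Server error: {0:?}")]
    Generation(String),
}

fn main() {
    let err = ClientError::Connection("connection refused".to_string());
    // thiserror derives Display from the #[error(...)] attribute, so this prints:
    // Could not connect to Text Generation server: "connection refused"
    println!("{}", err);
}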
router/client/src/sharded_client.rs

 use crate::Result;
 use crate::{Batch, Client, GeneratedText};
 use futures::future::join_all;
-use std::time::Duration;
 use tokio::sync::{broadcast, mpsc};
 use tonic::transport::Uri;

@@ -69,24 +68,22 @@ impl ShardedClient {
         Self { request_tx }
     }

-    async fn from_master_client(mut master_client: Client) -> Self {
+    async fn from_master_client(mut master_client: Client) -> Result<Self> {
         let uris = master_client.service_discovery().await.unwrap();
         let futures = uris
             .into_iter()
-            .map(|path| Client::connect_uds(path, Duration::from_secs(5)));
-        let clients = join_all(futures).await;
-        Self::new(clients)
+            .map(|path| Client::connect_uds(path));
+        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
+        Ok(Self::new(clients?))
     }

-    /// Returns a client connected to the given url. Requests exceeding timeout will fail.
-    pub async fn connect(uri: Uri, timeout: Duration) -> Self {
-        let master_client = Client::connect(uri, timeout).await;
+    /// Returns a client connected to the given url
+    pub async fn connect(uri: Uri) -> Result<Self> {
+        let master_client = Client::connect(uri).await?;
         Self::from_master_client(master_client).await
     }

-    /// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
-    pub async fn connect_uds(path: String, timeout: Duration) -> Self {
-        let master_client = Client::connect_uds(path, timeout).await;
+    /// Returns a client connected to the given unix socket
+    pub async fn connect_uds(path: String) -> Result<Self> {
+        let master_client = Client::connect_uds(path).await?;
         Self::from_master_client(master_client).await
     }
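`from_master_client` now collapses the per-shard connection results into a single `Result` by collecting, so one failed shard connection aborts startup instead of leaving a half-built client. A minimal standalone sketch of that collect-into-`Result` pattern, using toy types rather than the real `Client`:

// Vec<Result<T, E>> collapses to Result<Vec<T>, E>: the first Err wins.
fn collect_all(results: Vec<Result<u32, String>>) -> Result<Vec<u32>, String> {
    results.into_iter().collect()
}

fn main() {
    assert_eq!(collect_all(vec![Ok(1), Ok(2)]), Ok(vec![1, 2]));
    assert_eq!(
        collect_all(vec![Ok(1), Err("boom".to_string())]),
        Err("boom".to_string())
    );
    println!("ok");
}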
router/src/batcher.rs

 use crate::server::GenerateRequest;
-use crate::Db;
+use crate::{Db, Entry};
+use axum::http::StatusCode;
 use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
+use thiserror::Error;
 use tokio::sync::{oneshot, Notify};

 const MAX_LENGTH: usize = 128;

-pub struct InferError {}
+#[derive(Debug, Error)]
+pub enum InferError {
+    #[error("Request failed during generation: {0}")]
+    GenerationError(String),
+    #[error("Model is overloaded")]
+    Overloaded,
+}
+
+impl From<InferError> for (StatusCode, String) {
+    fn from(err: InferError) -> Self {
+        match err {
+            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
+            InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
+        }
+    }
+}

 #[derive(Clone)]
 pub(crate) struct Batcher {
@@ -37,14 +54,18 @@ impl Batcher {
         request: GenerateRequest,
     ) -> Result<String, InferError> {
         if self.db.len() > MAX_LENGTH {
-            return Err(InferError {});
+            return Err(InferError::Overloaded);
         }
         let (request_tx, request_rx) = oneshot::channel();
-        self.db.append(input_length, request, request_tx);
+        self.db.append(Entry {
+            request,
+            response_tx: request_tx,
+            input_length,
+        });
         self.shared.batching_task.notify_waiters();

         match request_rx.await.unwrap() {
             Ok(output) => Ok(output),
-            Err(_) => Err(InferError {}),
+            Err(err) => Err(InferError::GenerationError(err.to_string())),
         }
     }
 }
@@ -108,7 +129,6 @@ async fn wrap_future(
             next_batch
         }
         Err(err) => {
-            println!("{:?}", err);
             send_error(err, request_ids, db);
             None
         }
@@ -117,14 +137,18 @@ async fn wrap_future(
 fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
     request_ids.into_iter().for_each(|id| {
-        let (_, response_tx) = db.remove(&id).unwrap();
-        response_tx.send(Err(error.clone())).unwrap_or(());
+        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry.response_tx.send(Err(error.clone())).unwrap_or(());
     });
 }

 fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
     finished.into_iter().for_each(|output| {
-        let (_, response_tx) = db.remove(&output.request.unwrap().id).unwrap();
-        response_tx.send(Ok(output.output)).unwrap_or(());
+        let entry = db
+            .remove(&output.request.unwrap().id)
+            .expect("ID not found in db. This is a bug.");
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry.response_tx.send(Ok(output.output)).unwrap_or(());
     });
 }
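The `From<InferError> for (StatusCode, String)` impl is what lets the handlers in server.rs apply `?` directly to `infer(...)`: `?` routes the error through that impl. A small standalone sketch of the mechanism; it keeps only one variant for brevity, and `fake_infer` and `handler` are hypothetical stand-ins, not code from the commit:

use axum::http::StatusCode;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Model is overloaded")]
    Overloaded,
}

impl From<InferError> for (StatusCode, String) {
    fn from(err: InferError) -> Self {
        (StatusCode::TOO_MANY_REQUESTS, err.to_string())
    }
}

// Hypothetical stand-in for Batcher::infer, used only to trigger the conversion.
fn fake_infer() -> Result<String, InferError> {
    Err(InferError::Overloaded)
}

// Handler-shaped function: `?` turns InferError into (StatusCode, String).
fn handler() -> Result<String, (StatusCode, String)> {
    let text = fake_infer()?;
    Ok(text)
}

fn main() {
    let (code, msg) = handler().unwrap_err();
    println!("{code}: {msg}"); // 429 Too Many Requests: Model is overloaded
}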
router/src/db.rs

 /// This code is massively inspired by Tokio mini-redis
-use crate::server::GenerateRequest;
+use crate::server::{GenerateParameters, GenerateRequest};
 use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::RwLock;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use tokio::sync::oneshot::Sender;

+#[derive(Debug)]
+pub(crate) struct Entry {
+    pub request: GenerateRequest,
+    pub response_tx: Sender<Result<String, ClientError>>,
+    pub input_length: usize,
+}
+
+impl From<GenerateParameters> for LogitsWarperParameters {
+    fn from(parameters: GenerateParameters) -> Self {
+        Self {
+            temperature: parameters.temperature,
+            top_k: parameters.top_k as u32,
+            top_p: parameters.top_p,
+            do_sample: parameters.do_sample,
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 pub(crate) struct Db {
     pub shared: Arc<Shared>,
@@ -18,7 +36,7 @@ pub struct Shared {
 #[derive(Debug)]
 struct State {
-    entries: BTreeMap<u64, (Request, Sender<Result<String, ClientError>>)>,
+    entries: BTreeMap<u64, Entry>,

     /// Identifier to use for the next expiration. Each expiration is associated
     /// with a unique identifier. See above for why.
@@ -44,37 +62,16 @@ impl Db {
         Self { shared }
     }

-    pub(crate) fn append(
-        &self,
-        input_length: usize,
-        request: GenerateRequest,
-        sender: Sender<Result<String, ClientError>>,
-    ) {
+    pub(crate) fn append(&self, entry: Entry) {
         let mut state = self.shared.state.write();

         let id = state.next_id;
         state.next_id += 1;

-        let parameters = Some(LogitsWarperParameters {
-            temperature: request.parameters.temperature,
-            top_k: request.parameters.top_k,
-            top_p: request.parameters.top_p,
-            do_sample: request.parameters.do_sample,
-        });
-        let request = Request {
-            id,
-            inputs: request.inputs,
-            input_length: input_length as u32,
-            parameters,
-            max_new_tokens: request.parameters.max_new_tokens,
-        };
-
-        state.entries.insert(id, (request, sender));
+        state.entries.insert(id, entry);
     }

-    pub(crate) fn remove(
-        &self,
-        id: &u64,
-    ) -> Option<(Request, Sender<Result<String, ClientError>>)> {
+    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
         let mut state = self.shared.state.write();
         state.entries.remove(id)
     }
@@ -91,7 +88,15 @@ impl Db {
             .entries
             .range(state.next_batch_start_id..)
             .take(max_size)
-            .map(|(_, (request, _))| request.clone())
+            .map(|(id, entry)| Request {
+                id: *id,
+                inputs: entry.request.inputs.clone(),
+                input_length: entry.input_length as u32,
+                parameters: Some(LogitsWarperParameters::from(
+                    entry.request.parameters.clone(),
+                )),
+                max_new_tokens: entry.request.parameters.max_new_tokens,
+            })
             .collect();

         if requests.is_empty() {
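With the stored value now an `Entry`, `next_batch` builds the protobuf `Request`s on demand by walking the ordered map from `next_batch_start_id` and taking at most `max_size` entries. A standalone sketch of that `range(..).take(..)` walk over a `BTreeMap`, with toy values in place of the real `Entry`:

use std::collections::BTreeMap;

fn main() {
    let mut entries: BTreeMap<u64, &str> = BTreeMap::new();
    for (id, inputs) in [(0, "a"), (1, "b"), (2, "c"), (3, "d")] {
        entries.insert(id, inputs);
    }
    let next_batch_start_id = 1u64;
    let max_size = 2;
    // Take up to max_size entries whose id is >= next_batch_start_id,
    // in ascending id order, without removing them from the map.
    let batch: Vec<(u64, &str)> = entries
        .range(next_batch_start_id..)
        .take(max_size)
        .map(|(id, inputs)| (*id, *inputs))
        .collect();
    assert_eq!(batch, vec![(1, "b"), (2, "c")]);
    println!("{:?}", batch);
}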
router/src/lib.rs  (new file, mode 100644)

+mod batcher;
+mod db;
+pub mod server;
+mod validation;
+
+use batcher::Batcher;
+use db::{Db, Entry};
+use validation::Validation;
router/src/main.rs

 use bloom_inference_client::ShardedClient;
 use std::net::SocketAddr;
-use std::time::Duration;
+use text_generation_router::server;
 use tokenizers::Tokenizer;

-mod server;
-mod validation;
-
-use validation::Validation;
-
-mod db;
-
-use db::Db;
-
-mod batcher;
-
-use batcher::Batcher;
-
 fn main() -> Result<(), std::io::Error> {
     let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();
@@ -26,11 +13,9 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             tracing_subscriber::fmt::init();

-            let sharded_client = ShardedClient::connect_uds(
-                "/tmp/bloom-inference-0".to_string(),
-                Duration::from_secs(5),
-            )
-            .await;
+            let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string())
+                .await
+                .expect("Could not connect to server");
             sharded_client
                 .clear_cache()
                 .await
router/src/server.rs

-use bloom_inference_client::ShardedClient;
 use crate::{Batcher, Validation};
 use axum::extract::Extension;
 use axum::http::StatusCode;
 use axum::routing::{get, post};
 use axum::{Json, Router};
+use bloom_inference_client::ShardedClient;
 use serde::Deserialize;
 use std::net::SocketAddr;
 use tokenizers::Tokenizer;
@@ -15,7 +15,7 @@ pub(crate) struct GenerateParameters {
     #[serde(default = "default_temperature")]
     pub temperature: f32,
     #[serde(default = "default_top_k")]
-    pub top_k: u32,
+    pub top_k: i32,
     #[serde(default = "default_top_p")]
     pub top_p: f32,
     #[serde(default = "default_do_sample")]
@@ -28,7 +28,7 @@ fn default_temperature() -> f32 {
     1.0
 }

-fn default_top_k() -> u32 {
+fn default_top_k() -> i32 {
     0
 }
@@ -62,8 +62,8 @@ pub(crate) struct GenerateRequest {
 }

 #[instrument(skip(state), fields(time, time_per_token))]
-async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
-    let output = state
+async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
+    state
         .infer
         .infer(
             1,
@@ -78,50 +78,37 @@ async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
             },
         )
-        .await;
-
-    match output {
-        Ok(_) => Ok(()),
-        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
-    }
+        .await?;
+    Ok(())
 }

 #[instrument(skip(state), fields(time, time_per_token))]
 async fn generate(
     state: Extension<ServerState>,
     req: Json<GenerateRequest>,
-) -> Result<Json<serde_json::Value>, StatusCode> {
+) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
     let start = Instant::now();

-    let (input_length, validated_request) = match state
+    let (input_length, validated_request) = state
         .validation
         .validate(GenerateRequest {
             inputs: req.inputs.clone(),
             parameters: req.parameters.clone(),
         })
-        .await
-    {
-        Ok(result) => result,
-        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
-    };
-
-    let output = state.infer.infer(input_length, validated_request).await;
-
-    match output {
-        Ok(generated_text) => {
-            tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
-            tracing::Span::current().record(
-                "time_per_token",
-                format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
-            );
-            tracing::info!("response: {}", generated_text);
-
-            Ok(Json(serde_json::json!({
-                "generated_text": generated_text,
-            })))
-        }
-        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
-    }
+        .await?;
+
+    let generated_text = state.infer.infer(input_length, validated_request).await?;
+
+    tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
+    tracing::Span::current().record(
+        "time_per_token",
+        format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
+    );
+    tracing::info!("response: {}", generated_text);
+
+    Ok(Json(serde_json::json!({
+        "generated_text": generated_text,
+    })))
 }

 #[derive(Clone)]
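The handlers now return `Result<_, (StatusCode, String)>` rather than `Result<_, StatusCode>`, so the message carried by the thiserror enums reaches the HTTP client instead of being flattened to a bare 500. This relies on axum (0.5 assumed here) implementing `IntoResponse` for a `(StatusCode, String)` tuple; a minimal sketch of what that conversion yields:

use axum::http::StatusCode;
use axum::response::IntoResponse;

fn main() {
    // A (StatusCode, String) error tuple becomes a plain-text response with
    // that status, so callers see e.g. a 400 plus the validation message.
    let response = (
        StatusCode::BAD_REQUEST,
        "Temperature must be strictly positive".to_string(),
    )
        .into_response();
    println!("{}", response.status()); // 400 Bad Request
}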
router/src/validation.rs

 use crate::server::GenerateRequest;
+use axum::http::StatusCode;
+use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::{mpsc, oneshot};

-#[derive(Debug)]
-pub struct ValidationError {}
+#[derive(Error, Debug)]
+pub enum ValidationError {
+    #[error("Temperature must be strictly positive")]
+    Temperature,
+    #[error("Top p must be <= 0.0 or > 1.0")]
+    TopP,
+    #[error("Top k must be strictly positive")]
+    TopK,
+    #[error("Max New Tokens must be < 512")]
+    MaxNewTokens,
+    #[error("Inputs must have less than 512 tokens. Given: {0}")]
+    InputLength(usize),
+}
+
+impl From<ValidationError> for (StatusCode, String) {
+    fn from(err: ValidationError) -> Self {
+        (StatusCode::BAD_REQUEST, err.to_string())
+    }
+}

 type ValidationRequest = (
     GenerateRequest,
@@ -39,15 +58,23 @@ impl Validation {
 async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) {
     while let Some((request, response_tx)) = receiver.recv().await {
         if request.parameters.temperature < 0.0 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx
+                .send(Err(ValidationError::Temperature))
+                .unwrap_or(());
             continue;
         }
         if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx.send(Err(ValidationError::TopP)).unwrap_or(());
+            continue;
+        }
+        if request.parameters.top_k < 0 {
+            response_tx.send(Err(ValidationError::TopK)).unwrap_or(());
             continue;
         }
         if request.parameters.max_new_tokens > 512 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx
+                .send(Err(ValidationError::MaxNewTokens))
+                .unwrap_or(());
             continue;
         }
@@ -55,11 +82,12 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
         let input_length = inputs.len();

         if input_length > 512 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx
+                .send(Err(ValidationError::InputLength(input_length)))
+                .unwrap_or(());
             continue;
         }

         response_tx.send(Ok((input_length, request))).unwrap_or(());
     }
-    println!("drop here");
 }
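To make the new rules easy to read outside the channel-driven task, here is a standalone sketch that applies the same parameter checks as `validation_task` to plain arguments. The free function, its argument order, and the trimmed-down enum are invented for the example; the thresholds mirror the diff:

#[derive(Debug)]
pub enum ValidationError {
    Temperature,
    TopP,
    TopK,
    MaxNewTokens,
}

// Same checks as the task above, expressed as one pure function.
fn validate(temperature: f32, top_p: f32, top_k: i32, max_new_tokens: u32) -> Result<(), ValidationError> {
    if temperature < 0.0 {
        return Err(ValidationError::Temperature);
    }
    if top_p <= 0.0 || top_p > 1.0 {
        return Err(ValidationError::TopP);
    }
    if top_k < 0 {
        return Err(ValidationError::TopK);
    }
    if max_new_tokens > 512 {
        return Err(ValidationError::MaxNewTokens);
    }
    Ok(())
}

fn main() {
    assert!(validate(1.0, 0.95, 0, 20).is_ok());
    assert!(matches!(validate(1.0, 0.0, 0, 20), Err(ValidationError::TopP)));
    println!("ok");
}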
run.sh  (mode 100755 → 100644)

 #!/usr/bin/env bash

-server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
+server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"

-$server_cmd &
-
-FILE=/tmp/bloom-inference-0
+# Run in background
+$server_cmd 2>&1 > /dev/null &
+
+# Check if server is running by checking if the unix socket is created
+FILE=/tmp/bloom-inference-0

 while :
 do
     if test -S "$FILE"; then
@@ -18,4 +20,11 @@ while :
     sleep 1

-exec "bloom-inference"
+# Run in background
+text-generation-router &
+
+# Wait for any process to exit
+wait -n
+
+# Exit with status of process that exited first
+exit $?
\ No newline at end of file
server/bloom_inference/main.py → server/bloom_inference/cli.py  (renamed)

@@ -4,13 +4,16 @@ from pathlib import Path
 from torch.distributed.launcher import launch_agent, LaunchConfig
 from typing import Optional

-from bloom_inference.server import serve
+from bloom_inference import server

+app = typer.Typer()

-def main(
+
+@app.command()
+def launcher(
     model_name: str,
     num_gpus: int = 1,
     shard_directory: Optional[Path] = None,
 ):
     if num_gpus == 1:
         serve(model_name, False, shard_directory)
@@ -23,8 +26,17 @@ def main(
             rdzv_backend="c10d",
             max_restarts=0,
         )
-        launch_agent(config, serve, [model_name, True, shard_directory])
+        launch_agent(config, server.serve, [model_name, True, shard_directory])
+
+
+@app.command()
+def serve(
+    model_name: str,
+    sharded: bool = False,
+    shard_directory: Optional[Path] = None,
+):
+    server.serve(model_name, sharded, shard_directory)


 if __name__ == "__main__":
-    typer.run(main)
+    app()
server/bloom_inference/server.py

 import asyncio
+import os

 from grpc import aio
 from grpc_reflection.v1alpha import reflection
@@ -143,7 +145,3 @@ def serve(model_name, sharded, shard_directory):
         await server.wait_for_termination()

     asyncio.run(serve_inner(model_name, sharded, shard_directory))
-
-
-if __name__ == "__main__":
-    serve("bigscience/bloom-560m", True, Path("/tmp/models"))
server/bloom_inference/utils.py

@@ -2,6 +2,8 @@ import os
 import contextlib

 import torch
 import torch.distributed
+
+from datetime import timedelta

 from transformers.generation_logits_process import (
     LogitsProcessorList,
     TemperatureLogitsWarper,
@@ -79,6 +81,7 @@ def initialize_torch_distributed():
         backend=backend,
         world_size=world_size,
         rank=rank,
+        timeout=timedelta(seconds=60),
         init_method="tcp://localhost:6000",
     )
server/pyproject.toml

@@ -4,6 +4,9 @@ version = "0.1.0"
 description = "BLOOM Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

+[tool.poetry.scripts]
+bloom-inference-server = 'bloom_inference.cli:app'
+
 [tool.poetry.dependencies]
 python = "^3.9"
 protobuf = "^4.21.7"