OpenDAS / text-generation-inference / Commits / 4c693e65

Commit 4c693e65, authored Oct 11, 2022 by Olivier Dehaene

Refactored gRPC interface
Added validation logic

Parent: fa9a0884

Showing 13 changed files with 613 additions and 362 deletions (+613 / -362)
Changed files:

  README.md                             +19   -2
  proto/generate.proto                  +74   -39
  router/client/src/client.rs           +44   -17
  router/client/src/lib.rs              +1    -3
  router/client/src/sharded_client.rs   +51   -13
  router/src/batcher.rs                 +53   -69
  router/src/db.rs                      +10   -0
  router/src/main.rs                    +28   -14
  router/src/server.rs                  +23   -17
  router/src/validation.rs              +65   -0
  server/bloom_inference/cache.py       +6    -31
  server/bloom_inference/model.py       +158  -134
  server/bloom_inference/server.py      +81   -23
README.md

-# BLOOM Inference
+# Text Generation Inference

-A Rust and gRPC server for BLOOM Inference.
+A Rust and gRPC server for text generation inference.

 ## Load Tests

 See `k6/load_test.js`.

 We send the default examples with a 1 second delay between each request.

 Stages:
 - Ramp up to 50 concurrent requests per second in 1min
 - Ramp up from 50 to 100 concurrent requests per second in 2min
 - Ramp down to 0 concurrent requests per second in 1min

 |                        | avg       | min       | med       | max        | p(90)     | p(95)     | RPS      |
 |------------------------|-----------|-----------|-----------|------------|-----------|-----------|----------|
 | Original code          | 8.9s      | 1s        | 9.12s     | 16.69s     | 13.7s     | 14.26s    | 5.9      |
 | ISO with original code | 8.88s     | 959.53ms  | 8.89s     | 17.08s     | 13.34s    | 14.12s    | 5.94     |
 | New batching logic     | **5.44s** | **1.27s** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |

 ## Install
 ...
proto/generate.proto

...
@@ -2,21 +2,35 @@ syntax = "proto3";

 package generate.v1;

-service TextGeneration {
-    /// Service discovery
-    rpc ServiceDiscovery(Empty) returns (ServiceDiscoveryResponse) {}
-    /// Empties batch cache
-    rpc ClearCache(Empty) returns (Empty);
-    /// Generate tokens for a batch without cache
-    rpc Generate(Batch) returns (Response);
-    /// Generate tokens for a batch with cache
-    rpc GenerateWithCache(BatchCached) returns (Response);
+service TextGenerationService {
+    /// Service discovery
+    rpc ServiceDiscovery(ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
+    /// Empties batch cache
+    rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
+    /// Generate tokens for a batch
+    rpc Generate(GenerateRequest) returns (GenerateResponse);
+    /// Generate tokens for a list of cached batches
+    rpc GenerateWithCache(GenerateWithCacheRequest) returns (GenerateWithCacheResponse);
+    /// Generate tokens until the text of at least one request of the batch is generated
+    rpc GenerateUntilFinished(GenerateUntilFinishedRequest) returns (GenerateUntilFinishedResponse);
+    /// Generate tokens until the text of at least one request of the cached batches i finished
+    rpc GenerateUntilFinishedWithCache(GenerateUntilFinishedWithCacheRequest) returns (GenerateUntilFinishedWithCacheResponse);
 }

+/// Empty request
+message ServiceDiscoveryRequest {}
+
 message ServiceDiscoveryResponse {
     /// Other shards urls
     repeated string urls = 1;
 }

+/// Empty request
+message ClearCacheRequest {}
+
+/// Empty response
+message ClearCacheResponse {}
+
 message LogitsWarperParameters {
     float temperature = 1;
     uint32 top_k = 2;
...

@@ -29,10 +43,12 @@ message Request {
     uint64 id = 1;
     /// The generation context
     string inputs = 2;
+    /// The number of tokens inside inputs
+    uint32 input_length = 3;
     /// Logits Warper Parameters
-    LogitsWarperParameters parameters = 3;
+    LogitsWarperParameters parameters = 4;
     /// Stopping criteria
-    uint32 max_new_tokens = 4;
+    uint32 max_new_tokens = 5;
 }

 message Batch {
...

@@ -40,44 +56,63 @@ message Batch {
     uint64 id = 1;
     /// Individual requests
     repeated Request requests = 2;
+    /// Batch size (==len(requests))
+    uint32 size = 3;
+    /// Length of the longest sequence within the batch (used for padding)
+    uint32 max_sequence_length = 4;
 }

-message BatchCached {
-    /// Batch ID
-    uint64 id = 1;
-    /// Request ids within cache
-    repeated uint64 request_ids = 2;
-    /// Cache IDs
-    repeated uint64 batch_cached_ids = 3;
-    /// Batch size (sum of all batch sizes)
-    uint32 total_batch_size = 4;
-    /// Max sequence length
-    uint32 max_sequence_length = 5;
-}
-
-message FinishedGeneration {
-    /// ID of the original request
-    uint64 id = 1;
+message GeneratedText {
+    /// Request
+    Request request = 1;
     /// Output
     string output = 2;
 }

-message CacheEntry {
-    /// Cache ID; same as batch ID
-    uint64 id = 1;
-    /// Requests present in cache entry
-    repeated uint64 request_ids = 2;
-    /// Sequence length
-    uint32 sequence_length = 3;
-}
-
-message Response {
-    /// Finished requests (optional)
-    repeated FinishedGeneration finished = 1;
-    /// Cache entry (optional)
-    optional CacheEntry cache_entry = 2;
-}
-
-// Represent an empty message.
-message Empty {}
\ No newline at end of file
+message GenerateRequest {
+    /// Batch
+    Batch batch = 1;
+}
+
+message GenerateResponse {
+    /// Finished requests
+    repeated GeneratedText generated_texts = 1;
+    /// Next batch (cached)
+    optional Batch batch = 2;
+}
+
+message GenerateWithCacheRequest {
+    /// Cached batches
+    repeated Batch batches = 1;
+}
+
+message GenerateWithCacheResponse {
+    /// Finished requests
+    repeated GeneratedText generated_texts = 1;
+    /// Next batch (cached)
+    optional Batch batch = 2;
+}
+
+message GenerateUntilFinishedRequest {
+    /// Batch
+    Batch batch = 1;
+}
+
+message GenerateUntilFinishedResponse {
+    /// Finished requests
+    repeated GeneratedText generated_texts = 1;
+    /// Next batch (cached)
+    optional Batch batch = 2;
+}
+
+message GenerateUntilFinishedWithCacheRequest {
+    /// Cached batches
+    repeated Batch batches = 1;
+}
+
+message GenerateUntilFinishedWithCacheResponse {
+    /// Finished requests
+    repeated GeneratedText generated_texts = 1;
+    /// Next batch (cached)
+    optional Batch batch = 2;
+}
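To make the shape of the refactored interface concrete, here is a small, self-contained sketch of the new request/response envelopes. It is not part of the commit: the structs are hand-written stand-ins for the prost-generated Rust types (their field names mirror the proto above), and the little `main` only illustrates how a caller wraps a fresh `Batch` in `GenerateRequest` and then sends the server-side cached `Batch` back as a list inside `GenerateWithCacheRequest`.

#![allow(dead_code)]

// Hand-written stand-ins for the prost-generated types in pb::generate::v1
// (assumption: field names mirror proto/generate.proto after this commit).
#[derive(Clone, Debug)]
struct Request {
    id: u64,
    inputs: String,
    input_length: u32,
    max_new_tokens: u32,
}

#[derive(Clone, Debug)]
struct Batch {
    id: u64,
    requests: Vec<Request>,
    size: u32,
    max_sequence_length: u32,
}

#[derive(Debug)]
struct GeneratedText {
    request: Request,
    output: String,
}

// The dedicated envelopes that replace the old bare `Batch` / `Response` pair.
struct GenerateRequest {
    batch: Option<Batch>,
}
struct GenerateResponse {
    generated_texts: Vec<GeneratedText>,
    batch: Option<Batch>,
}
struct GenerateWithCacheRequest {
    batches: Vec<Batch>,
}
struct GenerateWithCacheResponse {
    generated_texts: Vec<GeneratedText>,
    batch: Option<Batch>,
}

fn main() {
    // A fresh batch, as the router would build it from validated requests.
    let request = Request {
        id: 0,
        inputs: "Hello".to_string(),
        input_length: 1,
        max_new_tokens: 8,
    };
    let batch = Batch {
        id: 0,
        size: 1,
        max_sequence_length: request.input_length,
        requests: vec![request],
    };

    // First decode step: the batch travels inside the new GenerateRequest.
    let first_call = GenerateRequest { batch: Some(batch) };

    // Pretend the server answered with no finished text and a cached batch.
    let reply = GenerateResponse {
        generated_texts: vec![],
        batch: first_call.batch,
    };

    // Subsequent steps: cached batches are sent back as a list, which is what
    // lets the server concatenate several cached batches before decoding.
    if let Some(cached) = reply.batch {
        let next_call = GenerateWithCacheRequest {
            batches: vec![cached],
        };
        println!(
            "next call carries {} cached batch(es); {} generation(s) finished so far",
            next_call.batches.len(),
            reply.generated_texts.len()
        );
    }
}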
router/client/src/client.rs

-use crate::pb::generate::v1::text_generation_client::TextGenerationClient;
+use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
 use crate::pb::generate::v1::*;
 use crate::Result;
 use std::time::Duration;
...
@@ -9,7 +9,7 @@ use tracing::*;

 /// BLOOM Inference gRPC client
 #[derive(Clone)]
 pub struct Client {
-    stub: TextGenerationClient<Timeout<Channel>>,
+    stub: TextGenerationServiceClient<Timeout<Channel>>,
 }

 impl Client {
...
@@ -22,13 +22,13 @@ impl Client {
         let timeout_channel = Timeout::new(channel, timeout);

         Self {
-            stub: TextGenerationClient::new(timeout_channel),
+            stub: TextGenerationServiceClient::new(timeout_channel),
         }
     }

     /// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
     pub async fn connect_uds(path: String, timeout: Duration) -> Self {
-        let channel = Channel::from_shared(format!("http://[::]:50051"))
+        let channel = Channel::from_shared("http://[::]:50051".to_string())
             .unwrap()
             .connect_with_connector(tower::service_fn(move |_: Uri| {
                 tokio::net::UnixStream::connect(path.clone())
...
@@ -38,13 +38,13 @@ impl Client {
         let timeout_channel = Timeout::new(channel, timeout);

         Self {
-            stub: TextGenerationClient::new(timeout_channel),
+            stub: TextGenerationServiceClient::new(timeout_channel),
         }
     }

     #[instrument(skip(self))]
     pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
-        let request = tonic::Request::new(Empty {});
+        let request = tonic::Request::new(ServiceDiscoveryRequest {});
         let response = self
             .stub
             .service_discovery(request)
...
@@ -64,7 +64,7 @@ impl Client {
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self) -> Result<()> {
-        let request = tonic::Request::new(Empty {});
+        let request = tonic::Request::new(ClearCacheRequest {});
         self.stub
             .clear_cache(request)
             .instrument(info_span!("clear_cache"))
...
@@ -73,32 +73,59 @@ impl Client {
     }

     #[instrument(skip(self))]
-    pub async fn generate(
-        &mut self,
-        request: Batch,
-    ) -> Result<(Vec<FinishedGeneration>, Option<CacheEntry>)> {
-        let request = tonic::Request::new(request);
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
         let response = self
             .stub
             .generate(request)
             .instrument(info_span!("generate"))
             .await?
             .into_inner();
-        Ok((response.finished, response.cache_entry))
+        Ok((response.generated_texts, response.batch))
     }

     #[instrument(skip(self))]
     pub async fn generate_with_cache(
         &mut self,
-        request: BatchCached,
-    ) -> Result<(Vec<FinishedGeneration>, Option<CacheEntry>)> {
-        let request = tonic::Request::new(request);
+        batches: Vec<Batch>,
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let request = tonic::Request::new(GenerateWithCacheRequest { batches });
         let response = self
             .stub
             .generate_with_cache(request)
             .instrument(info_span!("generate_with_cache"))
             .await?
             .into_inner();
-        Ok((response.finished, response.cache_entry))
+        Ok((response.generated_texts, response.batch))
     }

+    #[instrument(skip(self))]
+    pub async fn generate_until_finished(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let request = tonic::Request::new(GenerateUntilFinishedRequest { batch: Some(batch) });
+        let response = self
+            .stub
+            .generate_until_finished(request)
+            .instrument(info_span!("generate_until_finished"))
+            .await?
+            .into_inner();
+        Ok((response.generated_texts, response.batch))
+    }
+
+    #[instrument(skip(self))]
+    pub async fn generate_until_finished_with_cache(
+        &mut self,
+        batches: Vec<Batch>,
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let request = tonic::Request::new(GenerateUntilFinishedWithCacheRequest { batches });
+        let response = self
+            .stub
+            .generate_until_finished_with_cache(request)
+            .instrument(info_span!("generate_until_finished_with_cache"))
+            .await?
+            .into_inner();
+        Ok((response.generated_texts, response.batch))
+    }
 }
router/client/src/lib.rs

...
@@ -5,9 +5,7 @@ mod pb;
 mod sharded_client;

 pub use client::Client;
-pub use pb::generate::v1::{
-    Batch, BatchCached, CacheEntry, FinishedGeneration, LogitsWarperParameters, Request,
-};
+pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
 pub use tonic::transport::Uri;
...
router/client/src/sharded_client.rs

 use crate::Result;
-use crate::{Batch, BatchCached, CacheEntry, Client, FinishedGeneration};
+use crate::{Batch, Client, GeneratedText};
 use futures::future::join_all;
 use std::time::Duration;
 use tokio::sync::{broadcast, mpsc};
...
@@ -9,11 +9,19 @@ use tonic::transport::Uri;
 enum Command {
     Generate(
         Batch,
-        mpsc::Sender<Result<(Vec<FinishedGeneration>, Option<CacheEntry>)>>,
+        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
     ),
     GenerateWithCache(
-        BatchCached,
-        mpsc::Sender<Result<(Vec<FinishedGeneration>, Option<CacheEntry>)>>,
+        Vec<Batch>,
+        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
     ),
+    GenerateUntilFinished(
+        Batch,
+        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
+    ),
+    GenerateUntilFinishedWithCache(
+        Vec<Batch>,
+        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
+    ),
     ClearCache(mpsc::Sender<Result<()>>),
 }
...
@@ -25,8 +33,16 @@ async fn client_task(mut client: Client, mut request_subscriber: broadcast::Rece
             let result = client.generate(batch).await;
             response_tx.try_send(result).unwrap_or(());
         }
-        Command::GenerateWithCache(batch_cached, response_tx) => {
-            let result = client.generate_with_cache(batch_cached).await;
+        Command::GenerateWithCache(batches, response_tx) => {
+            let result = client.generate_with_cache(batches).await;
             response_tx.try_send(result).unwrap_or(());
         }
+        Command::GenerateUntilFinished(batch, response_tx) => {
+            let result = client.generate_until_finished(batch).await;
+            response_tx.try_send(result).unwrap_or(());
+        }
+        Command::GenerateUntilFinishedWithCache(batches, response_tx) => {
+            let result = client.generate_until_finished_with_cache(batches).await;
+            response_tx.try_send(result).unwrap_or(());
+        }
         Command::ClearCache(response_tx) => {
...
@@ -74,10 +90,7 @@ impl ShardedClient {
         Self::from_master_client(master_client).await
     }

-    pub async fn generate(
-        &self,
-        batch: Batch,
-    ) -> Result<(Vec<FinishedGeneration>, Option<CacheEntry>)> {
+    pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
         let (response_tx, mut response_rx) = mpsc::channel(1);
         self.request_tx
             .send(Command::Generate(batch, response_tx))
...
@@ -87,11 +100,36 @@ impl ShardedClient {
     pub async fn generate_with_cache(
         &self,
-        batch_cached: BatchCached,
-    ) -> Result<(Vec<FinishedGeneration>, Option<CacheEntry>)> {
+        batches: Vec<Batch>,
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
         let (response_tx, mut response_rx) = mpsc::channel(1);
         self.request_tx
-            .send(Command::GenerateWithCache(batch_cached, response_tx))
+            .send(Command::GenerateWithCache(batches, response_tx))
+            .unwrap();
+        response_rx.recv().await.unwrap()
+    }
+
+    pub async fn generate_until_finished(
+        &self,
+        batch: Batch,
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let (response_tx, mut response_rx) = mpsc::channel(1);
+        self.request_tx
+            .send(Command::GenerateUntilFinished(batch, response_tx))
+            .unwrap();
+        response_rx.recv().await.unwrap()
+    }
+
+    pub async fn generate_until_finished_with_cache(
+        &self,
+        batches: Vec<Batch>,
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let (response_tx, mut response_rx) = mpsc::channel(1);
+        self.request_tx
+            .send(Command::GenerateUntilFinishedWithCache(
+                batches,
+                response_tx,
+            ))
             .unwrap();
         response_rx.recv().await.unwrap()
     }
...
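The new command variants above all follow the same fan-out pattern that was already in place: every shard task subscribes to a tokio broadcast channel, receives the same command, and answers on a capacity-1 mpsc channel so only the first reply is kept. Here is a minimal, self-contained sketch of that pattern, not the crate's actual code: the one-variant `Command` and the string payloads are toy stand-ins, and the only assumed dependency is tokio with its rt, macros and sync features.

use tokio::sync::{broadcast, mpsc};

// Toy stand-in for the router's Command enum; the real variants carry
// Batch / Vec<Batch> payloads as shown in the diff above.
#[derive(Clone)]
enum Command {
    Generate(String, mpsc::Sender<String>),
}

async fn client_task(shard: usize, mut rx: broadcast::Receiver<Command>) {
    // Every shard receives every command from the broadcast channel.
    while let Ok(Command::Generate(prompt, response_tx)) = rx.recv().await {
        let result = format!("shard {shard} generated for {prompt:?}");
        // Replies after the first are dropped: the response channel only has
        // room for one message, hence try_send(..).unwrap_or(()).
        response_tx.try_send(result).unwrap_or(());
    }
}

#[tokio::main]
async fn main() {
    let (request_tx, _) = broadcast::channel(8);

    // One task per shard, each with its own subscription.
    for shard in 0..2 {
        tokio::spawn(client_task(shard, request_tx.subscribe()));
    }

    // The caller sends one command and keeps the first reply only.
    let (response_tx, mut response_rx) = mpsc::channel(1);
    request_tx
        .send(Command::Generate("Hello".to_string(), response_tx))
        .unwrap();
    println!("{}", response_rx.recv().await.unwrap());
}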
router/src/batcher.rs

-use crate::server::GenerateRequest;
 use crate::Db;
-use bloom_inference_client::{
-    Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient,
-};
+use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
+use std::future::Future;
 use std::sync::Arc;
-use tokio::sync::{Notify, oneshot};
+use crate::server::GenerateRequest;
+use tokio::sync::{oneshot, Notify};

 const MAX_LENGTH: usize = 128;
...
@@ -32,12 +31,16 @@ impl Batcher {
         Self { db, shared }
     }

-    pub(crate) async fn infer(&self, request: GenerateRequest) -> Result<String, InferError> {
+    pub(crate) async fn infer(
+        &self,
+        input_length: usize,
+        request: GenerateRequest,
+    ) -> Result<String, InferError> {
         if self.db.len() > MAX_LENGTH {
             return Err(InferError {});
         }
         let (request_tx, request_rx) = oneshot::channel();
-        self.db.append(request, request_tx);
+        self.db.append(input_length, request, request_tx);
         self.shared.batching_task.notify_waiters();
         match request_rx.await.unwrap() {
             Ok(output) => Ok(output),
...
@@ -51,76 +54,57 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
         shared.batching_task.notified().await;

         if let Some(batch) = db.next_batch(32) {
-            let mut cache_entry = infer_batch(batch, &client, &db).await;
-
-            loop {
-                if let Some(entry) = cache_entry {
-                    let mut batch_cached_ids = vec![entry.id];
-                    let mut total_batch_size = entry.request_ids.len();
-                    let mut max_sequence_length = entry.sequence_length;
-                    let mut request_ids = entry.request_ids;
-
-                    // if total_batch_size <= 16 {
-                    //     if let Some(batch) = db.next_batch_minimum_size(16, 48) {
-                    //         let other_cache_entry = infer_batch(batch, &client, &db).await;
-                    //
-                    //         if let Some(entry) = other_cache_entry {
-                    //             batch_cached_ids.push(entry.id);
-                    //             total_batch_size += entry.request_ids.len();
-                    //             max_sequence_length =
-                    //                 max_sequence_length.max(entry.sequence_length);
-                    //             request_ids.extend(entry.request_ids.into_iter());
-                    //         }
-                    //     }
-                    // }
-
-                    let batch_cached = BatchCached {
-                        id: entry.id,
-                        batch_cached_ids,
-                        total_batch_size: total_batch_size as u32,
-                        max_sequence_length,
-                        request_ids,
-                    };
-                    cache_entry = infer_batch_cached(batch_cached, &client, &db).await;
-                } else {
-                    break;
-                }
-            }
+            let request_ids = batch.requests.iter().map(|req| req.id).collect();
+            let mut cached_batch = match batch.size {
+                size if size > 16 => {
+                    wrap_future(client.generate_until_finished(batch), request_ids, &db).await
+                }
+                _ => wrap_future(client.generate(batch), request_ids, &db).await,
+            };
+
+            while let Some(batch) = cached_batch {
+                let batch_size = batch.size;
+                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
+                let mut batches = vec![batch];
+
+                if batch_size <= 16 {
+                    if let Some(new_batch) = db.next_batch_minimum_size(16, 48) {
+                        let new_batch_request_ids =
+                            new_batch.requests.iter().map(|req| req.id).collect();
+                        let new_cached_batch =
+                            wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
+                                .await;
+                        if let Some(new_cached_batch) = new_cached_batch {
+                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
+                            batches.push(new_cached_batch);
+                        }
+                    }
+                }
+
+                cached_batch = match batch_size {
+                    size if size > 16 => {
+                        wrap_future(
+                            client.generate_until_finished_with_cache(batches),
+                            request_ids,
+                            &db,
+                        )
+                        .await
+                    }
+                    _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
+                };
+            }
         }
     }
 }

-async fn infer_batch_cached(batch: BatchCached, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
-    match client.generate_with_cache(batch.clone()).await {
-        Ok((finished, cache_entry)) => {
-            send_finished(finished, db);
-            cache_entry
-        }
-        Err(err) => {
-            println!("{:?}", err);
-            send_error(err, batch.request_ids, &db);
-            None
-        }
-    }
-}
-
-async fn infer_batch(batch: Batch, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
-    match client.generate(batch.clone()).await {
-        Ok((finished, cache_entry)) => {
-            send_finished(finished, db);
-            cache_entry
-        }
-        Err(err) => {
-            println!("{:?}", err);
-            send_error(
-                err,
-                batch.requests.into_iter().map(|req| req.id).collect(),
-                &db,
-            );
-            None
-        }
-    }
-}
+async fn wrap_future(
+    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
+    request_ids: Vec<u64>,
+    db: &Db,
+) -> Option<Batch> {
+    match future.await {
+        Ok((generated_texts, next_batch)) => {
+            send_generated(generated_texts, db);
+            next_batch
+        }
+        Err(err) => {
+            println!("{:?}", err);
+            send_error(err, request_ids, db);
+            None
+        }
+    }
+}
...
@@ -133,9 +117,9 @@ fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
     });
 }

-fn send_finished(finished: Vec<FinishedGeneration>, db: &Db) {
+fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
     finished.into_iter().for_each(|output| {
-        let (_, response_tx) = db.remove(&output.id).unwrap();
+        let (_, response_tx) = db.remove(&output.request.unwrap().id).unwrap();
         response_tx.send(Ok(output.output)).unwrap_or(());
     });
 }
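The batching loop above is dense, so here is a condensed, self-contained sketch of the scheduling rule it implements. The `Rpc` enum and the two helper functions are illustrative stand-ins, not code from the commit: fresh batches above 16 requests are driven with the generate-until-finished RPCs, and smaller cached batches are first topped up from the waiting queue via `db.next_batch_minimum_size(16, 48)` before the next with-cache call.

// Illustrative stand-ins for the router's scheduling decision; not the crate's code.
#[derive(Debug, PartialEq)]
enum Rpc {
    Generate,
    GenerateUntilFinished,
    GenerateWithCache,
    GenerateUntilFinishedWithCache,
}

// Which RPC drives a fresh batch pulled from the database.
fn first_step(batch_size: u32) -> Rpc {
    if batch_size > 16 {
        Rpc::GenerateUntilFinished
    } else {
        Rpc::Generate
    }
}

// Which RPC drives the cached batch returned by the previous step, and whether
// the router should first try to pull more waiting requests (minimum 16, up to 48).
fn cached_step(batch_size: u32) -> (bool, Rpc) {
    let top_up = batch_size <= 16;
    let rpc = if batch_size > 16 {
        Rpc::GenerateUntilFinishedWithCache
    } else {
        Rpc::GenerateWithCache
    };
    (top_up, rpc)
}

fn main() {
    assert_eq!(first_step(32), Rpc::GenerateUntilFinished);
    assert_eq!(first_step(8), Rpc::Generate);
    assert_eq!(cached_step(8), (true, Rpc::GenerateWithCache));
    assert_eq!(cached_step(20), (false, Rpc::GenerateUntilFinishedWithCache));
    println!("scheduling rules hold");
}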
router/src/db.rs

...
@@ -46,6 +46,7 @@ impl Db {
     pub(crate) fn append(
         &self,
+        input_length: usize,
         request: GenerateRequest,
         sender: Sender<Result<String, ClientError>>,
     ) {
...
@@ -63,6 +64,7 @@ impl Db {
         let request = Request {
             id,
             inputs: request.inputs,
+            input_length: input_length as u32,
             parameters,
             max_new_tokens: request.parameters.max_new_tokens,
         };
...
@@ -103,9 +105,13 @@ impl Db {
     pub(crate) fn next_batch(&self, max_size: usize) -> Option<Batch> {
         if let Some((last_id, requests)) = self.next_requests(max_size) {
             let mut state = self.shared.state.write();
+            let size = requests.len();
+            let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
             let batch = Batch {
                 id: state.next_batch_id,
                 requests,
+                size: size as u32,
+                max_sequence_length,
             };
             state.next_batch_start_id = last_id + 1;
             state.next_batch_id += 1;
...
@@ -122,9 +128,13 @@ impl Db {
         if let Some((last_id, requests)) = self.next_requests(max_size) {
             if requests.len() >= min_size {
                 let mut state = self.shared.state.write();
+                let size = requests.len();
+                let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
                 let batch = Batch {
                     id: state.next_batch_id,
                     requests,
+                    size: size as u32,
+                    max_sequence_length,
                 };
                 state.next_batch_start_id = last_id + 1;
                 state.next_batch_id += 1;
...
router/src/main.rs

 use bloom_inference_client::ShardedClient;
 use poem;
 use poem::listener::TcpListener;
 use std::time::Duration;
+use tokenizers::Tokenizer;

 mod server;
+mod validation;
+use validation::Validation;
 mod db;
 use db::Db;
 mod batcher;
 use batcher::Batcher;

-#[tokio::main]
-async fn main() -> Result<(), std::io::Error> {
-    tracing_subscriber::fmt::init();
-
-    let sharded_client =
-        ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string(), Duration::from_secs(5))
-            .await;
-    sharded_client
-        .clear_cache()
-        .await
-        .expect("Unable to clear cache");
-    tracing::info!("Connected");
-
-    let addr = "127.0.0.1:3000".to_string();
-    let listener = TcpListener::bind(addr);
-
-    server::run(sharded_client, listener).await
-}
+fn main() -> Result<(), std::io::Error> {
+    let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();
+
+    tokio::runtime::Builder::new_multi_thread()
+        .enable_all()
+        .build()
+        .unwrap()
+        .block_on(async {
+            tracing_subscriber::fmt::init();
+
+            let sharded_client = ShardedClient::connect_uds(
+                "/tmp/bloom-inference-0".to_string(),
+                Duration::from_secs(5),
+            )
+            .await;
+            sharded_client
+                .clear_cache()
+                .await
+                .expect("Unable to clear cache");
+            tracing::info!("Connected");
+
+            let addr = "127.0.0.1:3000".to_string();
+            let listener = TcpListener::bind(addr);
+
+            server::run(sharded_client, tokenizer, listener).await
+        })
+}
router/src/server.rs

-use poem::{EndpointExt, handler, post, Route, Server};
+use crate::{Batcher, ShardedClient, Validation};
 use poem::http::StatusCode;
 use poem::listener::TcpListener;
 use poem::middleware::AddData;
 use poem::web::{Data, Json};
+use poem::{handler, post, EndpointExt, Route, Server};
+use serde::Deserialize;
+use tokenizers::Tokenizer;
 use tokio::time::Instant;
-use crate::{Batcher, ShardedClient};
 use tracing::instrument;
-use serde::Deserialize;

 #[derive(Clone, Debug, Deserialize)]
 pub(crate) struct GenerateParameters {
...
@@ -59,21 +60,24 @@ pub(crate) struct GenerateRequest {
     pub parameters: GenerateParameters,
 }

 #[handler]
-#[instrument(skip(infer), fields(time, time_per_token))]
+#[instrument(skip(validation, infer), fields(time, time_per_token))]
 async fn generate(
+    validation: Data<&Validation>,
     infer: Data<&Batcher>,
     req: Json<GenerateRequest>,
 ) -> poem::Result<Json<serde_json::Value>> {
     let start = Instant::now();

-    let output = infer
-        .infer(GenerateRequest {
+    let (input_length, validated_request) = validation
+        .validate(GenerateRequest {
             inputs: req.inputs.clone(),
             parameters: req.parameters.clone(),
         })
-        .await;
+        .await
+        .unwrap();
+
+    let output = infer.infer(input_length, validated_request).await;

     match output {
         Ok(generated_text) => {
...
@@ -92,20 +96,22 @@ async fn generate(
     }
 }

-pub async fn run(client: ShardedClient, listener: TcpListener<String>) -> Result<(), std::io::Error> {
-    client.clear_cache().await.expect("Unable to clear cache");
+pub async fn run(
+    client: ShardedClient,
+    tokenizer: Tokenizer,
+    listener: TcpListener<String>,
+) -> Result<(), std::io::Error> {
+    client.clear_cache().await.expect("Unable to clear cache");
     tracing::info!("Connected");

     let infer = Batcher::new(client);
+    let validation = Validation::new(tokenizer);

     let app = Route::new()
         .at("/generate", post(generate))
+        .with(AddData::new(validation))
         .with(AddData::new(infer));

-    Server::new(listener).run(app).await
-}
\ No newline at end of file
+    Server::new(listener).run(app).await
+}
router/src/validation.rs  0 → 100644

use crate::server::GenerateRequest;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};

#[derive(Debug)]
pub struct ValidationError {}

type ValidationRequest = (
    GenerateRequest,
    oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
);

#[derive(Debug, Clone)]
pub(crate) struct Validation {
    sender: mpsc::Sender<ValidationRequest>,
}

impl Validation {
    pub(crate) fn new(tokenizer: Tokenizer) -> Self {
        let (validation_sender, validation_receiver) = mpsc::channel(128);

        tokio::spawn(validation_task(tokenizer, validation_receiver));

        Self {
            sender: validation_sender,
        }
    }

    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<(usize, GenerateRequest), ValidationError> {
        let (sender, receiver) = oneshot::channel();
        self.sender.send((request, sender)).await.unwrap();
        receiver.await.unwrap()
    }
}

async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) {
    while let Some((request, response_tx)) = receiver.recv().await {
        if request.parameters.temperature < 0.0 {
            response_tx.send(Err(ValidationError {})).unwrap_or(());
            continue;
        }
        if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
            response_tx.send(Err(ValidationError {})).unwrap_or(());
            continue;
        }
        if request.parameters.max_new_tokens > 512 {
            response_tx.send(Err(ValidationError {})).unwrap_or(());
            continue;
        }

        let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
        let input_length = inputs.len();

        if input_length > 512 {
            response_tx.send(Err(ValidationError {})).unwrap_or(());
            continue;
        }

        response_tx.send(Ok((input_length, request))).unwrap_or(());
    }
    println!("drop here");
}
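For quick reference, the bounds enforced by the validation task above can be distilled into a single pure function. This is an illustrative condensation rather than code from the commit: the real task runs the tokenizer to obtain `input_length` and replies over a oneshot channel, while the sketch below only restates the numeric limits visible in the new file.

// Illustrative condensation of the checks in validation_task; the numeric
// limits are the ones in the new router/src/validation.rs above.
struct Parameters {
    temperature: f32,
    top_p: f32,
    max_new_tokens: u32,
}

fn is_valid(parameters: &Parameters, input_length: usize) -> bool {
    parameters.temperature >= 0.0
        && parameters.top_p > 0.0
        && parameters.top_p <= 1.0
        && parameters.max_new_tokens <= 512
        && input_length <= 512
}

fn main() {
    let ok = Parameters { temperature: 1.0, top_p: 0.95, max_new_tokens: 20 };
    let too_many_tokens = Parameters { temperature: 1.0, top_p: 0.95, max_new_tokens: 1024 };

    assert!(is_valid(&ok, 12));
    assert!(!is_valid(&too_many_tokens, 12));
    assert!(!is_valid(&ok, 4096)); // prompts longer than 512 tokens are rejected
    println!("validation bounds hold");
}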
server/bloom_inference/cache.py

-import torch
-
-from dataclasses import dataclass
-from typing import Dict, Optional, List
-
-from bloom_inference.pb import generate_pb2
-from bloom_inference.utils import NextTokenChooser, StoppingCriteria
-
-
-@dataclass
-class CacheEntry:
-    batch_id: int
-    request_ids: List[int]
-    input_ids: Dict[str, torch.Tensor]
-    all_input_ids: List[torch.Tensor]
-    next_token_choosers: List[NextTokenChooser]
-    stopping_criterias: List[StoppingCriteria]
-
-    def __len__(self):
-        return len(self.request_ids)
-
-    def to_pb(self):
-        return generate_pb2.CacheEntry(
-            id=self.batch_id,
-            request_ids=self.request_ids,
-            sequence_length=max(len(entry) for entry in self.all_input_ids),
-        )
+from bloom_inference.model import Batch
+
+from typing import Dict, Optional


 class Cache:
     def __init__(self):
-        self.cache: Dict[str, CacheEntry] = {}
+        self.cache: Dict[int, Batch] = {}

-    def pop(self, batch_id: str) -> Optional[CacheEntry]:
+    def pop(self, batch_id: int) -> Optional[Batch]:
         return self.cache.pop(batch_id, None)

-    def set(self, entry: CacheEntry):
+    def set(self, entry: Batch):
         if entry is not None:
             self.cache[entry.batch_id] = entry

-    def delete(self, batch_id: str):
+    def delete(self, batch_id: int):
         del self.cache[batch_id]

     def clear(self):
...
server/bloom_inference/model.py

...
@@ -8,7 +8,6 @@ from typing import List, Tuple, Optional, Dict
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 from transformers.modeling_utils import no_init_weights

-from bloom_inference.cache import CacheEntry
 from bloom_inference.pb import generate_pb2
 from bloom_inference.shard_model import shard_model, match_suffix
 from bloom_inference.utils import (
...
@@ -24,25 +23,35 @@ torch.manual_seed(0)
 @dataclass
 class Batch:
     batch_id: int
-    request_ids: List[int]
+    requests: List[generate_pb2.Request]
     input_ids: Dict[str, torch.Tensor]
     all_input_ids: List[torch.Tensor]
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    size: int
+    max_sequence_length: int
+
+    def to_pb(self):
+        return generate_pb2.Batch(
+            id=self.batch_id,
+            requests=self.requests,
+            size=self.size,
+            max_sequence_length=self.max_sequence_length,
+        )

     @classmethod
-    def from_batch_pb(
+    def from_pb(
         cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
     ) -> "Batch":
-        request_ids = []
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        input_lengths = []

         # Parse batch
         for r in pb.requests:
-            request_ids.append(r.id)
             inputs.append(r.inputs)
+            input_lengths.append(r.input_length)
             next_token_choosers.append(
                 NextTokenChooser(
                     temperature=r.parameters.temperature,
...
@@ -54,94 +63,93 @@ class Batch:
             stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))

         input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
-        all_input_ids = input_ids["input_ids"].unsqueeze(-1)
+        # Remove padding from all_input_ids
+        all_input_ids = [
+            input_ids.squeeze(0)[-length:].unsqueeze(-1)
+            for length, input_ids in zip(
+                input_lengths, input_ids["input_ids"].split(1, dim=0)
+            )
+        ]

         return cls(
-            pb.id,
-            request_ids,
-            input_ids,
-            all_input_ids,
-            next_token_choosers,
-            stopping_criterias,
+            batch_id=pb.id,
+            requests=pb.requests,
+            input_ids=input_ids,
+            all_input_ids=all_input_ids,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            size=pb.size,
+            max_sequence_length=pb.max_sequence_length,
         )

     @classmethod
-    def from_cache_entry(cls, cache_entry: CacheEntry) -> "Batch":
-        return cls(
-            cache_entry.batch_id,
-            cache_entry.request_ids,
-            cache_entry.input_ids,
-            cache_entry.all_input_ids,
-            cache_entry.next_token_choosers,
-            cache_entry.stopping_criterias,
-        )
-
-    @classmethod
-    def from_batch_cached_pb(cls, pb: generate_pb2.BatchCached, cache) -> "Batch":
-        if len(pb.batch_cached_ids) == 1:
-            cache_entry = cache.pop(pb.batch_cached_ids[0])
-            if cache_entry is None:
-                raise ValueError(f"Batch ID {pb.batch_id} not found in cache")
-            return cls.from_cache_entry(cache_entry)
-
-        total_batch_size = pb.total_batch_size
-        max_sequence_length = pb.max_sequence_length
+    def concatenate(cls, batches: List["Batch"]) -> "Batch":
+        # Used for padding
+        total_batch_size = sum(batch.size for batch in batches)
+        max_sequence_length = max(batch.max_sequence_length for batch in batches)

         # Batch attributes
         input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
-        request_ids = []
+        requests = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []

         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
         start_index = 0
-        for i, batch_id in enumerate(pb.batch_cached_ids):
-            cache_entry = cache.pop(batch_id)
-            if cache_entry is None:
-                raise ValueError(f"Batch ID {batch_id} not found in cache")
-            request_ids.extend(cache_entry.request_ids)
-            all_input_ids.extend(cache_entry.all_input_ids)
-            next_token_choosers.extend(cache_entry.next_token_choosers)
-            stopping_criterias.extend(cache_entry.stopping_criterias)
-            batch_size = len(cache_entry.request_ids)
-            end_index = start_index + batch_size
-            sequence_length = max(len(entry) for entry in cache_entry.all_input_ids)
-
-            if input_ids["input_ids"] is None:
+        for i, batch in enumerate(batches):
+            requests.extend(batch.requests)
+            all_input_ids.extend(batch.all_input_ids)
+            next_token_choosers.extend(batch.next_token_choosers)
+            stopping_criterias.extend(batch.stopping_criterias)
+
+            # Slicing end index for this batch
+            end_index = start_index + batch.size
+
+            # We only concatenate batches that did at least one step
+            if batch.input_ids["input_ids"].shape[1] > 1:
+                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
+
+            # Initialize tensors
+            if i == 0:
                 input_ids["input_ids"] = torch.empty(
                     (total_batch_size, 1),
-                    dtype=cache_entry.input_ids["input_ids"].dtype,
-                    device=cache_entry.input_ids["input_ids"].device,
+                    dtype=batch.input_ids["input_ids"].dtype,
+                    device=batch.input_ids["input_ids"].device,
                 )
-            input_ids["input_ids"][start_index:end_index] = cache_entry.input_ids["input_ids"]
-
-            if input_ids["attention_mask"] is None:
                 input_ids["attention_mask"] = torch.zeros(
                     (total_batch_size, max_sequence_length),
-                    dtype=cache_entry.input_ids["attention_mask"].dtype,
-                    device=cache_entry.input_ids["attention_mask"].device,
+                    dtype=batch.input_ids["attention_mask"].dtype,
+                    device=batch.input_ids["attention_mask"].device,
                 )

+            # input_ids["input_ids"] is always of shape [batch_size, 1]
+            # We do not need to pad it
+            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
+
             # We need to slice the attention mask to remove padding from previous steps
             input_ids["attention_mask"][
-                start_index:end_index, -sequence_length:
-            ] = cache_entry.input_ids["attention_mask"][:, -sequence_length:]
+                start_index:end_index, -batch.max_sequence_length:
+            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:]

-            for j, past in enumerate(cache_entry.input_ids["past_key_values"]):
+            # TODO: this could be done without the views by using indices
+            for j, past in enumerate(batch.input_ids["past_key_values"]):
                 past_keys = past[0]
                 past_values = past[1]

                 _, head_dim, padded_sequence_length = past_keys.shape

                 # Reshape the tensors to make slicing easier
                 past_keys = past_keys.view(
-                    batch_size, -1, head_dim, padded_sequence_length
+                    batch.size, -1, head_dim, padded_sequence_length
                 )
                 past_values = past_values.view(
-                    batch_size, -1, padded_sequence_length, head_dim
+                    batch.size, -1, padded_sequence_length, head_dim
                 )

                 num_heads = past_keys.shape[1]

+                # Initialize tensors
+                # This will run only once per layer
                 if j == len(input_ids["past_key_values"]):
                     padded_past_keys = torch.zeros(
                         (
...
@@ -167,15 +175,17 @@ class Batch:
                     [padded_past_keys, padded_past_values]
                 )

                 # We slice the past keys and values to remove the padding from previous batches
                 input_ids["past_key_values"][j][0][
-                    start_index:end_index, :, :, -(sequence_length - 1):
-                ] = past_keys[:, :, :, -(sequence_length - 1):]
+                    start_index:end_index, :, :, -(batch.max_sequence_length - 1):
+                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]

                 input_ids["past_key_values"][j][1][
-                    start_index:end_index, :, -(sequence_length - 1):, :
-                ] = past_values[:, :, -(sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]

-                if (i + 1) == len(pb.batch_cached_ids):
+                # If we are on the last batch, we need to reshape the tensors
+                if (i + 1) == len(batches):
                     input_ids["past_key_values"][j][0] = input_ids["past_key_values"][j][0].view(
                         total_batch_size * num_heads, head_dim, -1
                     )
...
@@ -183,27 +193,27 @@ class Batch:
                         j][1].view(total_batch_size * num_heads, -1, head_dim)

-            start_index += batch_size
-
-        assert pb.request_ids == request_ids
+            start_index += batch.size

         return cls(
-            pb.id,
-            request_ids,
-            input_ids,
-            all_input_ids,
-            next_token_choosers,
-            stopping_criterias,
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            input_ids=input_ids,
+            all_input_ids=all_input_ids,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            size=total_batch_size,
+            max_sequence_length=max_sequence_length,
        )


 @dataclass
-class FinishedGeneration:
-    request_id: str
+class GeneratedText:
+    request: generate_pb2.Request
     output: str

-    def to_pb(self) -> generate_pb2.FinishedGeneration:
-        return generate_pb2.FinishedGeneration(id=self.request_id, output=self.output)
+    def to_pb(self) -> generate_pb2.GeneratedText:
+        return generate_pb2.GeneratedText(request=self.request, output=self.output)


 class BLOOM:
...
@@ -229,25 +239,28 @@ class BLOOM:
         )

     def generate_token(
-        self, batch: Batch
-    ) -> Tuple[List[FinishedGeneration], Optional[CacheEntry]]:
+        self, batch: Batch
+    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
         with torch.no_grad():
             outputs = self.forward(**batch.input_ids)

-        # List of indices to cache
-        cache_indices = []
-        cache_past_indices = []
-
-        # New input_ids for next forward; keep in cache
-        cache_next_input_ids = []
-        cache_all_input_ids = []
+        next_batch_keep_indices = []
+        next_batch_past_keep_indices = []
+
+        # New input_ids for next forward
+        next_batch_input_ids = []
+        next_batch_all_input_ids = []
+        next_batch_size = 0
+        next_batch_max_sequence_length = 0

         # Finished requests
-        finished_generations: List[FinishedGeneration] = []
+        generated_texts: List[GeneratedText] = []

         # Zipped iterator
         iterator = zip(
-            batch.request_ids,
+            batch.requests,
             outputs.logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
...
@@ -256,11 +269,11 @@ class BLOOM:
         # For each member of the batch
         for i, (
-            request_id,
+            request,
             logits,
             next_token_chooser,
             stopping_criteria,
             all_tokens,
         ) in enumerate(iterator):
             # Select next token
             next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
...
@@ -274,64 +287,75 @@ class BLOOM:
                 output = self.tokenizer.decode(
                     all_tokens.squeeze(-1), skip_special_tokens=True
                 )
                 # Add to the list of finished generations with the original request
-                finished_generations.append(FinishedGeneration(request_id, output))
-            # must be added to the cache
+                generated_texts.append(GeneratedText(request, output))
+            # add to the next batch
             else:
-                cache_indices.append(i)
-                cache_past_indices.extend([j for j in range(i * self.num_heads, (i + 1) * self.num_heads)])
-                cache_next_input_ids.append(next_token)
-                cache_all_input_ids.append(all_tokens)
+                next_batch_keep_indices.append(i)
+                # past_key_values is of shape [batch_size * num_heads, ...]
+                # so we need to take into account the `num_heads` stride here
+                next_batch_past_keep_indices.extend(
+                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
+                )
+                next_batch_input_ids.append(next_token)
+                next_batch_all_input_ids.append(all_tokens)
+                next_batch_size += 1
+                next_batch_max_sequence_length = max(
+                    next_batch_max_sequence_length, len(all_tokens)
+                )

-        # No cache is needed, we finished all generations in the batch
-        if not cache_indices:
-            return finished_generations, None
+        # We finished all generations in the batch; there is no next batch
+        if not next_batch_keep_indices:
+            return generated_texts, None

-        cache_input_ids = {"input_ids": torch.cat(cache_next_input_ids, dim=0)}
+        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
         # If we finished at least one generation
-        if finished_generations:
+        if generated_texts:
             # Apply indices to attention mask, past key values and other items that need to be cached
-            cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"][cache_indices]
-            cache_input_ids["past_key_values"] = [
-                (keys[cache_past_indices], values[cache_past_indices])
-                for keys, values in outputs["past_key_values"]
-            ]
-            cache_request_ids = [batch.request_ids[i] for i in cache_indices]
-            cache_next_token_choosers = [batch.next_token_choosers[i] for i in cache_indices]
-            cache_stopping_criterias = [batch.stopping_criterias[i] for i in cache_indices]
+            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
+                next_batch_keep_indices
+            ]
+            next_batch_input_ids["past_key_values"] = [
+                (
+                    keys[next_batch_past_keep_indices],
+                    values[next_batch_past_keep_indices],
+                )
+                for keys, values in outputs["past_key_values"]
+            ]
+            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
+            next_batch_next_token_choosers = [
+                batch.next_token_choosers[i] for i in next_batch_keep_indices
+            ]
+            next_batch_stopping_criterias = [
+                batch.stopping_criterias[i] for i in next_batch_keep_indices
+            ]
         else:
-            cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
-            cache_input_ids["past_key_values"] = outputs["past_key_values"]
-            cache_request_ids = batch.request_ids
-            cache_next_token_choosers = batch.next_token_choosers
-            cache_stopping_criterias = batch.stopping_criterias
+            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
+            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
+            next_batch_requests = batch.requests
+            next_batch_next_token_choosers = batch.next_token_choosers
+            next_batch_stopping_criterias = batch.stopping_criterias

         # Update attention_mask with padding as we added a new token to input_ids
-        cache_input_ids["attention_mask"] = torch.cat(
-            [
-                cache_input_ids["attention_mask"],
-                torch.ones((cache_input_ids["attention_mask"].shape[0], 1)).to(
-                    cache_input_ids["attention_mask"].device
-                ),
-            ],
-            dim=1,
-        )
+        next_batch_input_ids["attention_mask"] = torch.cat(
+            [
+                next_batch_input_ids["attention_mask"],
+                torch.ones((next_batch_size, 1)).to(self.device),
+            ],
+            dim=1,
+        )

-        cache_entry = CacheEntry(
-            batch.batch_id,
-            cache_request_ids,
-            cache_input_ids,
-            cache_all_input_ids,
-            cache_next_token_choosers,
-            cache_stopping_criterias,
-        )
-        return finished_generations, cache_entry
+        next_batch = Batch(
+            batch_id=batch.batch_id,
+            requests=next_batch_requests,
+            input_ids=next_batch_input_ids,
+            all_input_ids=next_batch_all_input_ids,
+            next_token_choosers=next_batch_next_token_choosers,
+            stopping_criterias=next_batch_stopping_criterias,
+            size=next_batch_size,
+            max_sequence_length=next_batch_max_sequence_length,
+        )
+        return generated_texts, next_batch


 class BLOOMSharded(BLOOM):
...
server/bloom_inference/server.py

...
@@ -10,7 +10,7 @@ from bloom_inference.model import BLOOM, Batch, BLOOMSharded
 from bloom_inference.pb import generate_pb2_grpc, generate_pb2


-class TextGeneration(generate_pb2_grpc.TextGenerationServicer):
+class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]):
         self.cache = cache
         self.model = model
...
@@ -21,32 +21,90 @@ class TextGeneration(generate_pb2_grpc.TextGenerationServicer):
     async def ClearCache(self, request, context):
         self.cache.clear()
-        return generate_pb2.Empty()
+        return generate_pb2.ClearCacheResponse()

     async def Generate(self, request, context):
-        batch = Batch.from_batch_pb(request, self.model.tokenizer, self.model.device)
-        finished_generations, cache_entry = self.model.generate_token(batch)
-        self.cache.set(cache_entry)
-
-        return generate_pb2.Response(
-            finished=[
-                finished_generation.to_pb()
-                for finished_generation in finished_generations
-            ],
-            cache_entry=cache_entry.to_pb() if cache_entry else None,
-        )
+        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
+
+        generated_texts, next_batch = self.model.generate_token(batch)
+        self.cache.set(next_batch)
+
+        return generate_pb2.GenerateResponse(
+            generated_texts=[
+                generated_text.to_pb() for generated_text in generated_texts
+            ],
+            batch=next_batch.to_pb() if next_batch else None,
+        )

     async def GenerateWithCache(self, request, context):
-        batch = Batch.from_batch_cached_pb(request, self.cache)
-        finished_generations, cache_entry = self.model.generate_token(batch)
-        self.cache.set(cache_entry)
-
-        return generate_pb2.Response(
-            finished=[
-                finished_generation.to_pb()
-                for finished_generation in finished_generations
-            ],
-            cache_entry=cache_entry.to_pb() if cache_entry else None,
-        )
+        if len(request.batches) == 0:
+            raise ValueError("Must provide at least one batch")
+
+        batches = []
+        for batch_pb in request.batches:
+            batch = self.cache.pop(batch_pb.id)
+            if batch is None:
+                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
+            batches.append(batch)
+
+        if len(batches) > 1:
+            batch = Batch.concatenate(batches)
+        else:
+            batch = batches[0]
+
+        generated_texts, next_batch = self.model.generate_token(batch)
+        self.cache.set(next_batch)
+
+        return generate_pb2.GenerateWithCacheResponse(
+            generated_texts=[
+                generated_text.to_pb() for generated_text in generated_texts
+            ],
+            batch=next_batch.to_pb() if next_batch else None,
+        )
+
+    async def GenerateUntilFinished(self, request, context):
+        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
+
+        generated_texts = []
+        while not generated_texts:
+            generated_texts, next_batch = self.model.generate_token(batch)
+            batch = next_batch
+        self.cache.set(next_batch)
+
+        return generate_pb2.GenerateUntilFinishedResponse(
+            generated_texts=[
+                generated_text.to_pb() for generated_text in generated_texts
+            ],
+            batch=next_batch.to_pb() if next_batch else None,
+        )
+
+    async def GenerateUntilFinishedWithCache(self, request, context):
+        if len(request.batches) == 0:
+            raise ValueError("Must provide at least one batch")
+
+        batches = []
+        for batch_pb in request.batches:
+            batch = self.cache.pop(batch_pb.id)
+            if batch is None:
+                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
+            batches.append(batch)
+
+        if len(batches) > 1:
+            batch = Batch.concatenate(batches)
+        else:
+            batch = batches[0]
+
+        generated_texts = []
+        while not generated_texts:
+            generated_texts, next_batch = self.model.generate_token(batch)
+            batch = next_batch
+        self.cache.set(next_batch)
+
+        return generate_pb2.GenerateUntilFinishedWithCacheResponse(
+            generated_texts=[
+                generated_text.to_pb() for generated_text in generated_texts
+            ],
+            batch=next_batch.to_pb() if next_batch else None,
+        )
...
@@ -71,11 +129,11 @@ def serve(model_name, sharded, shard_directory):
     server_urls = [local_url]

     server = aio.server()
-    generate_pb2_grpc.add_TextGenerationServicer_to_server(
-        TextGeneration(model, Cache(), server_urls), server
+    generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
+        TextGenerationService(model, Cache(), server_urls), server
     )
     SERVICE_NAMES = (
-        generate_pb2.DESCRIPTOR.services_by_name["TextGeneration"].full_name,
+        generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
         reflection.SERVICE_NAME,
     )
     reflection.enable_server_reflection(SERVICE_NAMES, server)
...