OpenDAS / text-generation-inference · Commits

Commit f16f2f5a
v0.1.0
Authored Oct 18, 2022 by Olivier Dehaene; committed by OlivierDehaene, Oct 20, 2022
Parent: 92c1ecd0

Changes: 36
Showing 16 changed files with 600 additions and 281 deletions (+600 -281)
router/src/lib.rs                            +62   -2
router/src/main.rs                           +45   -7
router/src/server.rs                         +97   -69
router/src/validation.rs                     +100  -31
run.sh                                       +0    -30
rust-toolchain.toml                          +0    -0
server/.gitignore                            +155  -0
server/Makefile                              +16   -5
server/bloom_inference/cli.py                +33   -23
server/bloom_inference/model.py              +24   -12
server/bloom_inference/pb/.gitignore         +1    -1
server/bloom_inference/prepare_weights.py    +50   -47
server/bloom_inference/server.py             +16   -51
server/bloom_inference/utils.py              +0    -1
server/poetry.lock                           +1    -1
server/pyproject.toml                        +0    -1
router/src/lib.rs

/// Text Generation Inference Webserver
mod batcher;
mod db;
pub mod server;
mod validation;

use batcher::Batcher;
use db::{Db, Entry};
use serde::{Deserialize, Serialize};
use validation::Validation;

#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateParameters {
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_top_k")]
    pub top_k: i32,
    #[serde(default = "default_top_p")]
    pub top_p: f32,
    #[serde(default = "default_do_sample")]
    pub do_sample: bool,
    #[serde(default = "default_max_new_tokens")]
    pub max_new_tokens: u32,
}

fn default_temperature() -> f32 {
    1.0
}

fn default_top_k() -> i32 {
    0
}

fn default_top_p() -> f32 {
    1.0
}

fn default_do_sample() -> bool {
    false
}

fn default_max_new_tokens() -> u32 {
    20
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        temperature: default_temperature(),
        top_k: default_top_k(),
        top_p: default_top_p(),
        do_sample: default_do_sample(),
        max_new_tokens: default_max_new_tokens(),
    }
}

#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateRequest {
    pub inputs: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}

#[derive(Serialize)]
pub(crate) struct GeneratedText {
    pub generated_text: String,
}

pub(crate) type GenerateResponse = Vec<GeneratedText>;
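These request types map directly onto the JSON body accepted by the /generate route wired up in router/src/server.rs below. A minimal sketch, not part of the commit, assuming a router running locally on its default port 3000: omitting "parameters" falls back on the serde defaults above (temperature 1.0, top_k 0, top_p 1.0, do_sample false, max_new_tokens 20), and the response is a GenerateResponse, i.e. a JSON array of objects with a generated_text field.

# Rely entirely on the parameter defaults:
curl -s http://localhost:3000/generate \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "Hello"}'

# Override individual parameters:
curl -s http://localhost:3000/generate \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "Hello", "parameters": {"temperature": 0.8, "do_sample": true, "max_new_tokens": 32}}'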
router/src/main.rs

/// Text Generation Inference webserver entrypoint
use bloom_inference_client::ShardedClient;
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use text_generation_router::server;
use tokenizers::Tokenizer;

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    #[clap(default_value = "32", long, env)]
    max_batch_size: usize,
    #[clap(default_value = "5", long, env)]
    max_waiting_time: u64,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
}

fn main() -> Result<(), std::io::Error> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_input_length,
        max_batch_size,
        max_waiting_time,
        port,
        master_shard_uds_path,
        tokenizer_name,
        validation_workers,
    } = args;

    if validation_workers == 1 {
        panic!("validation_workers must be > 0");
    }

    let max_waiting_time = Duration::from_secs(max_waiting_time);

    // Download and instantiate tokenizer
    // This will only be used to validate payloads
    //
    // We need to download it outside of the Tokio runtime
    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        ...

@@ -39,18 +63,32 @@ fn main() -> Result<(), std::io::Error> {

        .block_on(async {
            tracing_subscriber::fmt::init();

            // Instantiate sharded client from the master unix socket
            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache()
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");

            // Binds on localhost
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);

            // Run server
            server::run(
                max_concurrent_requests,
                max_input_length,
                max_batch_size,
                max_waiting_time,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
            )
            .await;
            Ok(())
        })
}
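The new Args struct roughly translates into the command line below. This is an illustrative sketch, not taken from the commit: the flag names assume clap's default kebab-case derivation of the field names, every value shown simply restates the default, and each option can also come from an environment variable through the env attribute. The binary name is the one invoked by the (now deleted) run.sh.

text-generation-router \
    --max-concurrent-requests 128 \
    --max-input-length 1000 \
    --max-batch-size 32 \
    --max-waiting-time 5 \
    --port 3000 \
    --master-shard-uds-path /tmp/bloom-inference-0 \
    --tokenizer-name bigscience/bloom \
    --validation-workers 2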
router/src/server.rs

use crate::{
    Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::routing::{get, post};
use axum::{Json, Router};
use bloom_inference_client::ShardedClient;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::sync::Semaphore;
use tokio::time::Instant;
use tracing::instrument;

// Server shared state
#[derive(Clone)]
struct ServerState {
    validation: Validation,
    batcher: Batcher,
    limit_concurrent_requests: Arc<Semaphore>,
}

/// Health check method
#[instrument(skip(state), fields(time, time_per_token))]
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    //  be a bit too slow for a health check.
    //  What we should do instead if check if the gRPC channels are still healthy.

    // Limit concurrent requests by acquiring a permit from the semaphore
    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
        (
            StatusCode::TOO_MANY_REQUESTS,
            "Model is overloaded".to_string(),
        )
    })?;

    // Send a small inference request
    state
        .batcher
        .infer(
        ...

@@ -82,23 +58,35 @@ async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, Stri

    Ok(())
}

/// Generate method
#[instrument(skip(state), fields(time, time_per_token))]
async fn generate(
    state: Extension<ServerState>,
    req: Json<GenerateRequest>,
) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
    let start = Instant::now();
    // Limit concurrent requests by acquiring a permit from the semaphore
    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
        (
            StatusCode::TOO_MANY_REQUESTS,
            "Model is overloaded".to_string(),
        )
    })?;

    // Validate request
    let (input_length, validated_request) = state
        .validation
        // FIXME: can't we get rid of the cloning here??
        .validate(GenerateRequest {
            inputs: req.inputs.clone(),
            parameters: req.parameters.clone(),
        })
        .await?;

    // Inference
    let generated_text = state.batcher.infer(input_length, validated_request).await?;

    // Tracing metadata
    tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
    tracing::Span::current().record(
        "time_per_token",
        ...

@@ -106,31 +94,71 @@ async fn generate(

    );
    tracing::info!("response: {}", generated_text);

    // Send response
    let response = vec![GeneratedText { generated_text }];
    Ok(Json(response))
}

/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    max_concurrent_requests: usize,
    max_input_length: usize,
    max_batch_size: usize,
    max_waiting_time: Duration,
    client: ShardedClient,
    tokenizer: Tokenizer,
    validation_workers: usize,
    addr: SocketAddr,
) {
    // Create state
    let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
    let validation = Validation::new(validation_workers, tokenizer, max_input_length);
    let shared_state = ServerState {
        validation,
        batcher,
        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
    };

    // Create router
    let app = Router::new()
        .route("/generate", post(generate))
        .layer(Extension(shared_state.clone()))
        .route("/health", get(health))
        .layer(Extension(shared_state.clone()));

    // Run server
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
}

/// Shutdown signal handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
}
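A few hedged examples of how the reworked server behaves, assuming a router listening on localhost:3000; none of these commands appear in the commit itself.

# /health now takes a concurrency permit and runs a small inference request; expect 200 on success.
curl -s -o /dev/null -w '%{http_code}\n' http://localhost:3000/health

# When all max_concurrent_requests permits are in use, /generate and /health answer
# immediately with 429 and the body "Model is overloaded" instead of queueing.

# Ctrl+C or SIGTERM now goes through shutdown_signal(): the router logs
# "signal received, starting graceful shutdown" and waits for in-flight requests to finish.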
router/src/validation.rs

/// Payload validation logic
use crate::GenerateRequest;
use axum::http::StatusCode;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::{
    DecoderWrapper, ModelWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper,
    TokenizerImpl,
};
use tokio::sync::{mpsc, oneshot};

/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
    /// Channel to communicate with the background validation task
    sender: mpsc::Sender<ValidationRequest>,
}

impl Validation {
    pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
        // Crate channel
        let (validation_sender, validation_receiver) = mpsc::channel(128);

        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
            max_input_length,
            validation_receiver,
        ));

        Self {
            sender: validation_sender,
        }
    }

    /// Validate a payload and get the number of tokens in the input
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<(usize, GenerateRequest), ValidationError> {
        // Create response channel
        let (sender, receiver) = oneshot::channel();
        // Send request to the background validation task
        // Unwrap is safe here
        self.sender.send((request, sender)).await.unwrap();
        // Await on response channel
        // Unwrap is safe here
        receiver.await.unwrap()
    }
}

/// Validation task
/// Load balance the validation requests between multiple validation workers
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
        let tokenizer_clone = tokenizer.clone();
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker
        tokio::task::spawn_blocking(move || {
            validation_worker(tokenizer_clone, max_input_length, worker_receiver)
        });
    }

    loop {
        // Load balance requests between workers
        for sender in workers_senders.iter() {
            if let Some(validation_request) = receiver.recv().await {
                sender.send(validation_request).await.unwrap();
            } else {
                return;
            }
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
    tokenizer: TokenizerImpl<
        ModelWrapper,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >,
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    // Loop over requests
    while let Some((request, response_tx)) = receiver.blocking_recv() {
        if request.parameters.temperature < 0.0 {
            response_tx
                .send(Err(ValidationError::Temperature))
                ...

@@ -78,10 +121,11 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali

            continue;
        }

        // Get the number of tokens in the input
        let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
        let input_length = inputs.len();

        if input_length > max_input_length {
            response_tx
                .send(Err(ValidationError::InputLength(input_length)))
                .unwrap_or(());
                ...

@@ -91,3 +135,28 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali

        response_tx.send(Ok((input_length, request))).unwrap_or(());
    }
}

type ValidationRequest = (
    GenerateRequest,
    oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
);

#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("Temperature must be strictly positive")]
    Temperature,
    #[error("Top p must be <= 0.0 or > 1.0")]
    TopP,
    #[error("Top k must be strictly positive")]
    TopK,
    #[error("Max New Tokens must be < 512")]
    MaxNewTokens,
    #[error("Inputs must have less than 1000 tokens. Given: {0}")]
    InputLength(usize),
}

impl From<ValidationError> for (StatusCode, String) {
    fn from(err: ValidationError) -> Self {
        (StatusCode::BAD_REQUEST, err.to_string())
    }
}
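Validation failures surface to clients as HTTP 400 through the From<ValidationError> impl above. A hedged example, not part of the commit, assuming a router on localhost:3000:

curl -s -w '\n%{http_code}\n' http://localhost:3000/generate \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "Hello", "parameters": {"temperature": -1.0}}'
# Expected output: "Temperature must be strictly positive" followed by 400.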
run.sh (deleted, 100644 → 0)

#!/usr/bin/env bash

server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"

# Run in background
$server_cmd 2>&1 > /dev/null &

# Check if server is running by checking if the unix socket is created
FILE=/tmp/bloom-inference-0

while :
do
    if test -S "$FILE"; then
        echo "Text Generation Python gRPC server started"
        break
    else
        echo "Waiting for Text Generation Python gRPC server to start"
        sleep 5
    fi
done

sleep 1

# Run in background
text-generation-router &

# Wait for any process to exit
wait -n

# Exit with status of process that exited first
exit $?
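With run.sh gone, the two processes have to be started by hand or by an image entrypoint. A minimal sketch of the equivalent orchestration, assuming the Python package from this commit is installed; the socket path works out to /tmp/bloom-inference-0 because the CLI's default uds_path is /tmp/bloom-inference and the router's default master_shard_uds_path matches it.

bloom-inference-server serve bigscience/bloom-560m &

# Same readiness check the old script used: wait for the unix socket to appear.
until test -S /tmp/bloom-inference-0; do
    echo "Waiting for Text Generation Python gRPC server to start"
    sleep 5
done

text-generation-router --master-shard-uds-path /tmp/bloom-inference-0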
router/rust-toolchain.toml → rust-toolchain.toml (file moved)
server/.gitignore (new file, 0 → 100644)

# Byte-compiled / optimized / DLL files
__pycache__/
bloom_inference/__pycache__/
bloom_inference/pb/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
server/Makefile

@@ -4,17 +4,28 @@ gen-server:
	find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
	touch bloom_inference/pb/__init__.py

-unit-tests:
-	python -m pytest --cov=bloom_inference tests
install-transformers:
	# Install specific version of transformers
	rm transformers || true
	wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
	unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
	rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
	mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers
	cd transformers && python setup.py install

-unit-tests-reporting:
-	python -m pytest --junitxml=report.xml --cov=bloom_inference tests
install-torch:
	# Install specific version of torch
	pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir

pip-install:
	pip install grpcio-tools
	make gen-server
	make install-torch
	make install-transformers
	pip install .

install:
	poetry install
	make gen-server
	make install-torch
	make install-transformers
server/bloom_inference/cli.py

import os
import typer

from pathlib import Path
-from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional

from bloom_inference import prepare_weights, server

app = typer.Typer()


-@app.command()
-def launcher(
-    model_name: str,
-    num_gpus: int = 1,
-    shard_directory: Optional[Path] = None,
-):
-    if num_gpus == 1:
-        serve(model_name, False, shard_directory)
-    else:
-        config = LaunchConfig(
-            min_nodes=1,
-            max_nodes=1,
-            nproc_per_node=num_gpus,
-            rdzv_backend="c10d",
-            max_restarts=0,
-        )
-        launch_agent(config, server.serve, [model_name, True, shard_directory])
@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    shard_directory: Optional[Path] = None,
    uds_path: Path = "/tmp/bloom-inference",
):
    if sharded:
        assert (
            shard_directory is not None
        ), "shard_directory must be set when sharded is True"
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    server.serve(model_name, sharded, uds_path, shard_directory)


@app.command()
def prepare_weights(
    model_name: str,
    shard_directory: Path,
    cache_directory: Path,
    num_shard: int = 1,
):
    prepare_weights.prepare_weights(
        model_name, cache_directory, shard_directory, num_shard
    )


if __name__ == "__main__":
    ...
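Illustrative invocations of the two Typer commands above; the exact spelling of the second command depends on Typer's function-name mapping (it may be exposed as prepare-weights), and the paths and MASTER_PORT value are placeholders rather than values from the commit.

# Single-process serving on the default unix socket prefix /tmp/bloom-inference:
bloom-inference-server serve bigscience/bloom-560m

# Sharded serving now relies on torchrun-style environment variables instead of the
# removed `launcher` command:
RANK=0 WORLD_SIZE=2 MASTER_ADDR=localhost MASTER_PORT=29500 \
    bloom-inference-server serve bigscience/bloom --sharded --shard-directory /data/shards

# Shard the weights ahead of time (one file per tensor-parallel rank):
bloom-inference-server prepare-weights bigscience/bloom /data/shards /data/cache --num-shard 2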
server/bloom_inference/model.py

@@ -24,6 +24,7 @@ torch.manual_seed(0)
class Batch:
    batch_id: int
    requests: List[generate_pb2.Request]
    all_input_lengths: List[int]
    input_ids: Dict[str, torch.Tensor]
    all_input_ids: List[torch.Tensor]
    next_token_choosers: List[NextTokenChooser]
...
@@ -46,12 +47,12 @@ class Batch:
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        all_input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            all_input_lengths.append(r.input_length)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
...
@@ -63,17 +64,12 @@ class Batch:
            stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))

        input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
        # Remove padding from all_input_ids
        all_input_ids = [
            input_ids.squeeze(0)[-length:].unsqueeze(-1)
            for length, input_ids in zip(
                all_input_lengths, input_ids["input_ids"].split(1, dim=0)
            )
        ]

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
...
@@ -91,6 +87,7 @@ class Batch:
        # Batch attributes
        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
        requests = []
        all_input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
...
@@ -100,6 +97,7 @@ class Batch:
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            all_input_lengths.extend(batch.all_input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
...
@@ -198,6 +196,7 @@ class Batch:
        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
...
@@ -227,7 +226,10 @@ class BLOOM:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_name)
            .eval()
            .to(self.device)
            .to(dtype)
        )
        self.num_heads = self.model.base_model.num_heads
...
@@ -253,6 +255,7 @@ class BLOOM:
        # New input_ids for next forward
        next_batch_input_ids = []
        next_batch_all_input_ids = []
        next_all_input_lengths = []
        next_batch_size = 0
        next_batch_max_sequence_length = 0
...
@@ -263,6 +266,7 @@ class BLOOM:
        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.all_input_lengths,
            outputs.logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
...
@@ -272,6 +276,7 @@ class BLOOM:
        # For each member of the batch
        for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
...
@@ -302,8 +307,10 @@ class BLOOM:
                next_batch_input_ids.append(next_token)
                next_batch_all_input_ids.append(all_tokens)
                next_batch_size += 1
                new_input_length = input_length + 1
                next_all_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

        # We finished all generations in the batch; there is no next batch
...
@@ -350,6 +357,7 @@ class BLOOM:
        next_batch = Batch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            all_input_lengths=next_all_input_lengths,
            input_ids=next_batch_input_ids,
            all_input_ids=next_batch_all_input_ids,
            next_token_choosers=next_batch_next_token_choosers,
...
@@ -378,7 +386,10 @@ class BLOOMSharded(BLOOM):
        if self.master:
            # TODO @thomasw21 do some caching
            shard_state_dict_paths = prepare_weights(
                model_name,
                shard_directory / "cache",
                shard_directory,
                tp_world_size=self.world_size,
            )
            shard_state_dict_paths = [
                str(path.absolute()) for path in shard_state_dict_paths
...
@@ -443,6 +454,7 @@ class BLOOMSharded(BLOOM):
                use_cache=True,
            )

            # Logits are sharded, so we need to gather them
            logits_shard = outputs.logits[:, -1, :].contiguous()
            batch_size, vocab_shard_size = logits_shard.shape
...
server/bloom_inference/pb/.gitignore

*.py
*.py-e
server/bloom_inference/prepare_weights.py

@@ -14,15 +14,15 @@ from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
def match_suffix(text, suffix):
    return text[-len(suffix) :] == suffix


def http_get(
    url: str,
    temp_file: BinaryIO,
    *,
    timeout=10.0,
    max_retries=0,
):
    """
    Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
...
@@ -54,7 +54,9 @@ def cache_download_url(url: str, root_dir: Path):
    return filename


def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world_size: int):
    save_paths = [
        save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
        for tp_rank in range(tp_world_size)
...
@@ -68,6 +70,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
    if model_name == "bigscience/bloom-560m":
        url = hf_hub_url(model_name, filename="pytorch_model.bin")
        cache_download_url(url, cache_path)
    elif model_name == "bigscience/bloom":
        url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
        index_path = cache_download_url(url, cache_path)
...
@@ -75,10 +78,14 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
            index = json.load(f)

        # Get unique file names
        weight_files = list(
            set([filename for filename in index["weight_map"].values()])
        )
        urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]

        Parallel(n_jobs=5)(
            delayed(cache_download_url)(url, cache_path) for url in tqdm(urls)
        )
    else:
        raise ValueError(f"Unknown model name: {model_name}")
...
@@ -91,14 +98,14 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
        for state_name in keys:
            state = state_dict[state_name]
            if any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.query_key_value.weight",
                    "self_attention.query_key_value.bias",
                    "mlp.dense_h_to_4h.weight",
                    "mlp.dense_h_to_4h.bias",
                    "word_embeddings.weight",
                ]
            ):
                output_size = state.shape[0]
                assert output_size % tp_world_size == 0
...
@@ -107,7 +114,9 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
                assert len(sharded_weights) == tp_world_size
                for tp_rank, shard in enumerate(sharded_weights):
                    shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
            elif match_suffix(state_name, "lm_head.weight"):
                output_size = state.shape[0]
...
@@ -120,11 +129,11 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
                    shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
            elif any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.dense.weight",
                    "mlp.dense_4h_to_h.weight",
                ]
            ):
                input_size = state.shape[1]
                assert input_size % tp_world_size == 0
...
@@ -132,23 +141,31 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
                sharded_weights = torch.split(state, block_size, dim=1)
                assert len(sharded_weights) == tp_world_size
                for tp_rank, shard in enumerate(sharded_weights):
                    shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
            elif any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.dense.bias",
                    "mlp.dense_4h_to_h.bias",
                ]
            ):
                shards_state_dicts[0]["transformer." + state_name] = state.detach().clone()
                for tp_rank in range(1, tp_world_size):
                    shards_state_dicts[tp_rank]["transformer." + state_name] = torch.zeros_like(state)
            else:
                # We duplicate parameters across tp ranks
                for tp_rank in range(tp_world_size):
                    shards_state_dicts[tp_rank]["transformer." + state_name] = state.detach().clone()

            del state_dict[state_name]  # delete key from state_dict
            del state  # delete tensor
...
@@ -156,7 +173,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
    # we save state_dict
    for tp_rank, (save_path, shard_state_dict) in enumerate(
        zip(save_paths, shards_state_dicts)
    ):
        save_paths.append(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
...
@@ -166,17 +183,3 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
        torch.save(shard_state_dict, save_path)
    return save_paths

-if __name__ == "__main__":
-    from argparse import ArgumentParser
-
-    parser = ArgumentParser()
-    parser.add_argument("--model-name", required=True, type=str)
-    parser.add_argument("--cache-path", required=True, type=str)
-    parser.add_argument("--save-path", required=True, type=str)
-    parser.add_argument("--world-size", required=True, type=int)
-    args = parser.parse_args()
-
-    prepare_weights(
-        args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size
-    )
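The save_paths expression above encodes the on-disk layout of the shards; because the model name contains a slash, the files land in a subdirectory of save_path, which is why save_path.parent.mkdir is called before saving. A hedged illustration for tp_world_size=2 with a hypothetical save_path of /data/shards:

ls /data/shards/bigscience/
# bloom_tp-rank-0-of-2.pty
# bloom_tp-rank-1-of-2.pty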
server/bloom_inference/server.py

@@ -64,70 +64,31 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            batch=next_batch.to_pb() if next_batch else None,
        )

-    async def GenerateUntilFinished(self, request, context):
-        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
-
-        generated_texts = []
-        while not generated_texts:
-            generated_texts, next_batch = self.model.generate_token(batch)
-            batch = next_batch
-        self.cache.set(next_batch)
-
-        return generate_pb2.GenerateUntilFinishedResponse(
-            generated_texts=[
-                generated_text.to_pb() for generated_text in generated_texts
-            ],
-            batch=next_batch.to_pb() if next_batch else None,
-        )
-
-    async def GenerateUntilFinishedWithCache(self, request, context):
-        if len(request.batches) == 0:
-            raise ValueError("Must provide at least one batch")
-
-        batches = []
-        for batch_pb in request.batches:
-            batch = self.cache.pop(batch_pb.id)
-            if batch is None:
-                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
-            batches.append(batch)
-
-        if len(batches) > 1:
-            batch = Batch.concatenate(batches)
-        else:
-            batch = batches[0]
-
-        generated_texts = []
-        while not generated_texts:
-            generated_texts, next_batch = self.model.generate_token(batch)
-            batch = next_batch
-        self.cache.set(next_batch)
-
-        return generate_pb2.GenerateUntilFinishedWithCacheResponse(
-            generated_texts=[
-                generated_text.to_pb() for generated_text in generated_texts
-            ],
-            batch=next_batch.to_pb() if next_batch else None,
-        )


def serve(
    model_name: str,
    sharded: bool,
    uds_path: Path,
    shard_directory: Optional[Path] = None,
):
    async def serve_inner(
        model_name: str,
        sharded: bool = False,
        shard_directory: Optional[Path] = None,
    ):
        unix_socket_template = "unix://{}-{}"
        if sharded:
            if shard_directory is None:
                raise ValueError("shard_directory must be set when sharded is True")

            model = BLOOMSharded(model_name, shard_directory)
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(model.world_size)
            ]
            local_url = server_urls[model.rank]
        else:
            model = BLOOM(model_name)
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        server = aio.server()
        ...

@@ -142,6 +103,10 @@ def serve(model_name, sharded, shard_directory):

        server.add_insecure_port(local_url)
        await server.start()
        print("Server started at {}".format(local_url))
        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            print("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(serve_inner(model_name, sharded, shard_directory))
server/bloom_inference/utils.py

@@ -82,7 +82,6 @@ def initialize_torch_distributed():
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
-       init_method="tcp://localhost:6000",
    )
    return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
...
server/poetry.lock

@@ -205,7 +205,7 @@ python-versions = ">=3.7"
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
-content-hash = "f3dc5b2420183f2e7e9257e372489409d7bd26d1dcc535fc2558ebca50c988c2"
+content-hash = "a4eef5f52e8d046aa883082c865b0865047f611a3240b18250487d4b6e831496"

[metadata.files]
accelerate = [
...
server/pyproject.toml

@@ -11,7 +11,6 @@ bloom-inference-server = 'bloom_inference.cli:app'
python = "^3.9"
protobuf = "^4.21.7"
grpcio = "^1.49.1"
-torch = "^1.12.1"
typer = "^0.6.1"
grpcio-reflection = "^1.49.1"
accelerate = "^0.12.0"
...