OpenDAS / text-generation-inference · Commits

Commit b4024edd (Unverified)
Authored Jul 10, 2023 by OlivierDehaene; committed by GitHub on Jul 10, 2023

feat: better errors for warmup and TP (#575)
Close #571
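
In short: router/src/main.rs and router/src/server.rs stop panicking via .expect(...)/.unwrap() and instead propagate a new thiserror-based RouterError (main() now returns Result<(), RouterError>), while the Python attention modules replace bare asserts with ValueErrors reporting the offending `num_heads` and `num_shards` values, and the warmup failure in flash_causal_lm.py becomes a RuntimeError with remediation hints. Below is a minimal, self-contained sketch of the Rust pattern; the local ClientError stand-in and the warmup() helper are illustrative only (in the router they come from the text_generation_client crate and ShardedClient):

use thiserror::Error;

// Stand-in for text_generation_client::ClientError, only so this sketch
// compiles on its own; the real type lives in the text_generation_client crate.
#[derive(Debug, Error)]
#[error("shard client error")]
struct ClientError;

// Subset of the RouterError enum added in this commit.
#[derive(Debug, Error)]
enum RouterError {
    #[error("Unable to warmup the Python model shards: {0}")]
    Warmup(ClientError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}

// Hypothetical fallible call standing in for sharded_client.warmup(...).
fn warmup() -> Result<(), ClientError> {
    Err(ClientError)
}

fn main() -> Result<(), RouterError> {
    // Instead of .expect("Unable to warmup model") aborting with a panic,
    // the failure is mapped into a typed RouterError variant and bubbled up
    // with `?`, so callers see a descriptive error value rather than a panic.
    warmup().map_err(RouterError::Warmup)?;
    Ok(())
}
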
Parent: e943a294
Showing 12 changed files with 80 additions and 23 deletions:
router/src/main.rs (+25, -12)
router/src/server.rs (+4, -5)
server/text_generation_server/models/custom_modeling/bloom_modeling.py (+5, -0)
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (+5, -0)
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py (+6, -0)
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py (+6, -0)
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (+5, -1)
server/text_generation_server/models/custom_modeling/mpt_modeling.py (+6, -0)
server/text_generation_server/models/custom_modeling/neox_modeling.py (+6, -1)
server/text_generation_server/models/custom_modeling/opt_modeling.py (+5, -1)
server/text_generation_server/models/custom_modeling/t5_modeling.py (+5, -0)
server/text_generation_server/models/flash_causal_lm.py (+2, -3)

router/src/main.rs
@@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::Duration;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo};
+use thiserror::Error;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
@@ -70,7 +71,7 @@ struct Args {
     ngrok_password: Option<String>,
 }

-fn main() -> Result<(), std::io::Error> {
+fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
@@ -149,8 +150,7 @@ fn main() -> Result<(), std::io::Error> {
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
-        .build()
-        .unwrap()
+        .build()?
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
@@ -192,17 +192,14 @@ fn main() -> Result<(), std::io::Error> {
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
-                .expect("Could not connect to server");
+                .map_err(RouterError::Connection)?;
             // Clear the cache; useful if the webserver rebooted
             sharded_client
                 .clear_cache(None)
                 .await
-                .expect("Unable to clear cache");
+                .map_err(RouterError::Cache)?;
             // Get info from the shard
-            let shard_info = sharded_client
-                .info()
-                .await
-                .expect("Unable to get shard info");
+            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;

             // Warmup model
             tracing::info!("Warming up model");
@@ -213,7 +210,7 @@ fn main() -> Result<(), std::io::Error> {
                     max_batch_total_tokens,
                 )
                 .await
-                .expect("Unable to warmup model");
+                .map_err(RouterError::Warmup)?;
             tracing::info!("Connected");

             let addr = match hostname.parse() {
@@ -249,7 +246,7 @@ fn main() -> Result<(), std::io::Error> {
                 ngrok_username,
                 ngrok_password,
             )
-            .await;
+            .await?;
             Ok(())
         })
 }
@@ -331,3 +328,19 @@ pub async fn get_model_info(
     }
     None
 }
+
+#[derive(Debug, Error)]
+enum RouterError {
+    #[error("Unable to connect to the Python model shards: {0}")]
+    Connection(ClientError),
+    #[error("Unable to clear the Python model shards cache: {0}")]
+    Cache(ClientError),
+    #[error("Unable to get the Python model shards info: {0}")]
+    Info(ClientError),
+    #[error("Unable to warmup the Python model shards: {0}")]
+    Warmup(ClientError),
+    #[error("Tokio runtime failed to start: {0}")]
+    Tokio(#[from] std::io::Error),
+    #[error("Axum webserver failed: {0}")]
+    Axum(#[from] axum::BoxError),
+}

router/src/server.rs
@@ -527,7 +527,7 @@ pub async fn run(
     ngrok_domain: Option<String>,
     ngrok_username: Option<String>,
     ngrok_password: Option<String>,
-) {
+) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
@@ -726,8 +726,7 @@ pub async fn run(
             .serve(app.into_make_service())
             //Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
     #[cfg(not(feature = "ngrok"))]
     {
@@ -744,9 +743,9 @@ pub async fn run(
             .serve(app.into_make_service())
             // Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
+    Ok(())
 }

 /// Shutdown signal handler

server/text_generation_server/models/custom_modeling/bloom_modeling.py
@@ -256,6 +256,11 @@ class BloomAttention(nn.Module):
         self.beta = 1.0

         process_group = weights.process_group
+        if self.num_heads % process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load(
             config=config,

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module):
         self.softmax_scale = self.head_size**-0.5

+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,

server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = num_heads
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.rotary_emb = PositionRotaryEmbedding.load(

server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module):
             dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load(

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -208,7 +208,11 @@ class FlashMQAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads

-        assert self.num_heads % weights.process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.softmax_scale = self.head_size ** (-0.5)

server/text_generation_server/models/custom_modeling/mpt_modeling.py
@@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module):
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = config.attn_config["attn_pdrop"]
+
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // weights.process_group.size()
         self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias

server/text_generation_server/models/custom_modeling/neox_modeling.py
@@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module):
             torch.tensor(self.head_size, dtype=torch.float32)
         ).to(torch.get_default_dtype())

-        assert self.num_attention_heads % weights.process_group.size() == 0
+        if self.num_attention_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_attention_heads` must be divisible by `num_shards` "
+                f"(got `num_attention_heads`: {self.num_attention_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_attention_heads = (
             self.num_attention_heads // weights.process_group.size()
         )

server/text_generation_server/models/custom_modeling/opt_modeling.py
@@ -147,7 +147,11 @@ class OPTAttention(nn.Module):
         self.is_decoder = is_decoder

         process_group = weights.process_group
-        assert self.num_heads % process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.embed_dim = self.embed_dim // process_group.size()

server/text_generation_server/models/custom_modeling/t5_modeling.py
@@ -246,6 +246,11 @@ class T5Attention(nn.Module):
         self.o = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.o", weights=weights, bias=False
         )
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // process_group.size()
         self.inner_dim = self.inner_dim // process_group.size()

server/text_generation_server/models/flash_causal_lm.py
@@ -727,12 +727,11 @@ class FlashCausalLM(Model):
             )
             _, batch = self.generate_token(batch)
         except Exception as e:
-            logger.exception(
+            raise RuntimeError(
                 f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                 f"prefill tokens. "
                 f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
-            )
-            raise e
+            ) from e

         del batch
         torch.cuda.empty_cache()