Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
c58a0c18
Unverified
Commit
c58a0c18
authored
Jul 14, 2023
by
OlivierDehaene
Committed by
GitHub
Jul 14, 2023
Browse files
v0.9.2 (#616)
parent
5b9de4a1
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
262 additions
and
206 deletions
+262
-206
Cargo.lock
Cargo.lock
+167
-135
Cargo.toml
Cargo.toml
+1
-1
docs/openapi.json
docs/openapi.json
+1
-1
launcher/src/main.rs
launcher/src/main.rs
+92
-68
server/pyproject.toml
server/pyproject.toml
+1
-1
No files found.
Cargo.lock
View file @
c58a0c18
This diff is collapsed.
Click to expand it.
Cargo.toml
View file @
c58a0c18
...
@@ -8,7 +8,7 @@ members = [
...
@@ -8,7 +8,7 @@ members = [
]
]
[workspace.package]
[workspace.package]
version
=
"0.9.
1
"
version
=
"0.9.
2
"
edition
=
"2021"
edition
=
"2021"
authors
=
[
"Olivier Dehaene"
]
authors
=
[
"Olivier Dehaene"
]
homepage
=
"https://github.com/huggingface/text-generation-inference"
homepage
=
"https://github.com/huggingface/text-generation-inference"
...
...
docs/openapi.json
View file @
c58a0c18
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
"name"
:
"Apache 2.0"
,
"name"
:
"Apache 2.0"
,
"url"
:
"https://www.apache.org/licenses/LICENSE-2.0"
"url"
:
"https://www.apache.org/licenses/LICENSE-2.0"
},
},
"version"
:
"0.9.
1
"
"version"
:
"0.9.
2
"
},
},
"paths"
:
{
"paths"
:
{
"/"
:
{
"/"
:
{
...
...
launcher/src/main.rs
View file @
c58a0c18
...
@@ -7,7 +7,7 @@ use std::ffi::OsString;
...
@@ -7,7 +7,7 @@ use std::ffi::OsString;
use
std
::
io
::{
BufRead
,
BufReader
,
Read
};
use
std
::
io
::{
BufRead
,
BufReader
,
Read
};
use
std
::
os
::
unix
::
process
::{
CommandExt
,
ExitStatusExt
};
use
std
::
os
::
unix
::
process
::{
CommandExt
,
ExitStatusExt
};
use
std
::
path
::
Path
;
use
std
::
path
::
Path
;
use
std
::
process
::{
Child
,
Command
,
Stdio
};
use
std
::
process
::{
Child
,
Command
,
ExitStatus
,
Stdio
};
use
std
::
sync
::
atomic
::{
AtomicBool
,
Ordering
};
use
std
::
sync
::
atomic
::{
AtomicBool
,
Ordering
};
use
std
::
sync
::
mpsc
::
TryRecvError
;
use
std
::
sync
::
mpsc
::
TryRecvError
;
use
std
::
sync
::{
mpsc
,
Arc
};
use
std
::
sync
::{
mpsc
,
Arc
};
...
@@ -319,7 +319,7 @@ fn shard_manager(
...
@@ -319,7 +319,7 @@ fn shard_manager(
}
}
// Process args
// Process args
let
mut
shard_arg
v
=
vec!
[
let
mut
shard_arg
s
=
vec!
[
"serve"
.to_string
(),
"serve"
.to_string
(),
model_id
,
model_id
,
"--uds-path"
.to_string
(),
"--uds-path"
.to_string
(),
...
@@ -331,77 +331,77 @@ fn shard_manager(
...
@@ -331,77 +331,77 @@ fn shard_manager(
// Activate trust remote code
// Activate trust remote code
if
trust_remote_code
{
if
trust_remote_code
{
shard_arg
v
.push
(
"--trust-remote-code"
.to_string
());
shard_arg
s
.push
(
"--trust-remote-code"
.to_string
());
}
}
// Activate tensor parallelism
// Activate tensor parallelism
if
world_size
>
1
{
if
world_size
>
1
{
shard_arg
v
.push
(
"--sharded"
.to_string
());
shard_arg
s
.push
(
"--sharded"
.to_string
());
}
}
if
let
Some
(
quantize
)
=
quantize
{
if
let
Some
(
quantize
)
=
quantize
{
shard_arg
v
.push
(
"--quantize"
.to_string
());
shard_arg
s
.push
(
"--quantize"
.to_string
());
shard_arg
v
.push
(
quantize
.to_string
())
shard_arg
s
.push
(
quantize
.to_string
())
}
}
if
let
Some
(
dtype
)
=
dtype
{
if
let
Some
(
dtype
)
=
dtype
{
shard_arg
v
.push
(
"--dtype"
.to_string
());
shard_arg
s
.push
(
"--dtype"
.to_string
());
shard_arg
v
.push
(
dtype
.to_string
())
shard_arg
s
.push
(
dtype
.to_string
())
}
}
// Model optional revision
// Model optional revision
if
let
Some
(
revision
)
=
revision
{
if
let
Some
(
revision
)
=
revision
{
shard_arg
v
.push
(
"--revision"
.to_string
());
shard_arg
s
.push
(
"--revision"
.to_string
());
shard_arg
v
.push
(
revision
)
shard_arg
s
.push
(
revision
)
}
}
// OpenTelemetry
// OpenTelemetry
if
let
Some
(
otlp_endpoint
)
=
otlp_endpoint
{
if
let
Some
(
otlp_endpoint
)
=
otlp_endpoint
{
shard_arg
v
.push
(
"--otlp-endpoint"
.to_string
());
shard_arg
s
.push
(
"--otlp-endpoint"
.to_string
());
shard_arg
v
.push
(
otlp_endpoint
);
shard_arg
s
.push
(
otlp_endpoint
);
}
}
// Copy current process env
// Copy current process env
let
mut
env
:
Vec
<
(
OsString
,
OsString
)
>
=
env
::
vars_os
()
.collect
();
let
mut
env
s
:
Vec
<
(
OsString
,
OsString
)
>
=
env
::
vars_os
()
.collect
();
// Use cuda allocator. It leads to less memory fragmentation
// Use cuda allocator. It leads to less memory fragmentation
env
.push
((
env
s
.push
((
"PYTORCH_CUDA_ALLOC_CONF"
.into
(),
"PYTORCH_CUDA_ALLOC_CONF"
.into
(),
"backend:cudaMallocAsync"
.into
(),
"backend:cudaMallocAsync"
.into
(),
));
));
// Torch Distributed Env vars
// Torch Distributed Env vars
env
.push
((
"RANK"
.into
(),
rank
.to_string
()
.into
()));
env
s
.push
((
"RANK"
.into
(),
rank
.to_string
()
.into
()));
env
.push
((
"WORLD_SIZE"
.into
(),
world_size
.to_string
()
.into
()));
env
s
.push
((
"WORLD_SIZE"
.into
(),
world_size
.to_string
()
.into
()));
env
.push
((
"MASTER_ADDR"
.into
(),
master_addr
.into
()));
env
s
.push
((
"MASTER_ADDR"
.into
(),
master_addr
.into
()));
env
.push
((
"MASTER_PORT"
.into
(),
master_port
.to_string
()
.into
()));
env
s
.push
((
"MASTER_PORT"
.into
(),
master_port
.to_string
()
.into
()));
env
.push
((
"NCCL_ASYNC_ERROR_HANDLING"
.into
(),
"1"
.into
()));
env
s
.push
((
"NCCL_ASYNC_ERROR_HANDLING"
.into
(),
"1"
.into
()));
// Safetensors load fast
// Safetensors load fast
env
.push
((
"SAFETENSORS_FAST_GPU"
.into
(),
"1"
.into
()));
env
s
.push
((
"SAFETENSORS_FAST_GPU"
.into
(),
"1"
.into
()));
// Enable hf transfer for insane download speeds
// Enable hf transfer for insane download speeds
let
enable_hf_transfer
=
env
::
var
(
"HF_HUB_ENABLE_HF_TRANSFER"
)
.unwrap_or
(
"1"
.to_string
());
let
enable_hf_transfer
=
env
::
var
(
"HF_HUB_ENABLE_HF_TRANSFER"
)
.unwrap_or
(
"1"
.to_string
());
env
.push
((
env
s
.push
((
"HF_HUB_ENABLE_HF_TRANSFER"
.into
(),
"HF_HUB_ENABLE_HF_TRANSFER"
.into
(),
enable_hf_transfer
.into
(),
enable_hf_transfer
.into
(),
));
));
// Parse Inference API token
// Parse Inference API token
if
let
Ok
(
api_token
)
=
env
::
var
(
"HF_API_TOKEN"
)
{
if
let
Ok
(
api_token
)
=
env
::
var
(
"HF_API_TOKEN"
)
{
env
.push
((
"HUGGING_FACE_HUB_TOKEN"
.into
(),
api_token
.into
()))
env
s
.push
((
"HUGGING_FACE_HUB_TOKEN"
.into
(),
api_token
.into
()))
};
};
// If huggingface_hub_cache is some, pass it to the shard
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
// Useful when running inside a docker container
if
let
Some
(
huggingface_hub_cache
)
=
huggingface_hub_cache
{
if
let
Some
(
huggingface_hub_cache
)
=
huggingface_hub_cache
{
env
.push
((
"HUGGINGFACE_HUB_CACHE"
.into
(),
huggingface_hub_cache
.into
()));
env
s
.push
((
"HUGGINGFACE_HUB_CACHE"
.into
(),
huggingface_hub_cache
.into
()));
};
};
// If weights_cache_override is some, pass it to the shard
// If weights_cache_override is some, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint
// Useful when running inside a HuggingFace Inference Endpoint
if
let
Some
(
weights_cache_override
)
=
weights_cache_override
{
if
let
Some
(
weights_cache_override
)
=
weights_cache_override
{
env
.push
((
env
s
.push
((
"WEIGHTS_CACHE_OVERRIDE"
.into
(),
"WEIGHTS_CACHE_OVERRIDE"
.into
(),
weights_cache_override
.into
(),
weights_cache_override
.into
(),
));
));
...
@@ -409,24 +409,24 @@ fn shard_manager(
...
@@ -409,24 +409,24 @@ fn shard_manager(
// If disable_custom_kernels is true, pass it to the shard as an env var
// If disable_custom_kernels is true, pass it to the shard as an env var
if
disable_custom_kernels
{
if
disable_custom_kernels
{
env
.push
((
"DISABLE_CUSTOM_KERNELS"
.into
(),
"True"
.into
()))
env
s
.push
((
"DISABLE_CUSTOM_KERNELS"
.into
(),
"True"
.into
()))
}
}
// Watermark Gamma
// Watermark Gamma
if
let
Some
(
watermark_gamma
)
=
watermark_gamma
{
if
let
Some
(
watermark_gamma
)
=
watermark_gamma
{
env
.push
((
"WATERMARK_GAMMA"
.into
(),
watermark_gamma
.to_string
()
.into
()))
env
s
.push
((
"WATERMARK_GAMMA"
.into
(),
watermark_gamma
.to_string
()
.into
()))
}
}
// Watermark Delta
// Watermark Delta
if
let
Some
(
watermark_delta
)
=
watermark_delta
{
if
let
Some
(
watermark_delta
)
=
watermark_delta
{
env
.push
((
"WATERMARK_DELTA"
.into
(),
watermark_delta
.to_string
()
.into
()))
env
s
.push
((
"WATERMARK_DELTA"
.into
(),
watermark_delta
.to_string
()
.into
()))
}
}
// Start process
// Start process
tracing
::
info!
(
"Starting shard {rank}"
);
tracing
::
info!
(
"Starting shard {rank}"
);
let
mut
p
=
match
Command
::
new
(
"text-generation-server"
)
let
mut
p
=
match
Command
::
new
(
"text-generation-server"
)
.args
(
shard_arg
v
)
.args
(
shard_arg
s
)
.envs
(
env
)
.envs
(
env
s
)
.stdout
(
Stdio
::
piped
())
.stdout
(
Stdio
::
piped
())
.stderr
(
Stdio
::
piped
())
.stderr
(
Stdio
::
piped
())
.process_group
(
0
)
.process_group
(
0
)
...
@@ -632,7 +632,7 @@ enum LauncherError {
...
@@ -632,7 +632,7 @@ enum LauncherError {
}
}
fn
download_convert_model
(
args
:
&
Args
,
running
:
Arc
<
AtomicBool
>
)
->
Result
<
(),
LauncherError
>
{
fn
download_convert_model
(
args
:
&
Args
,
running
:
Arc
<
AtomicBool
>
)
->
Result
<
(),
LauncherError
>
{
let
mut
download_arg
v
=
vec!
[
let
mut
download_arg
s
=
vec!
[
"download-weights"
.to_string
(),
"download-weights"
.to_string
(),
args
.model_id
.to_string
(),
args
.model_id
.to_string
(),
"--extension"
.to_string
(),
"--extension"
.to_string
(),
...
@@ -644,35 +644,35 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
...
@@ -644,35 +644,35 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Model optional revision
// Model optional revision
if
let
Some
(
revision
)
=
&
args
.revision
{
if
let
Some
(
revision
)
=
&
args
.revision
{
download_arg
v
.push
(
"--revision"
.to_string
());
download_arg
s
.push
(
"--revision"
.to_string
());
download_arg
v
.push
(
revision
.to_string
())
download_arg
s
.push
(
revision
.to_string
())
}
}
// Copy current process env
// Copy current process env
let
mut
env
:
Vec
<
(
OsString
,
OsString
)
>
=
env
::
vars_os
()
.collect
();
let
mut
env
s
:
Vec
<
(
OsString
,
OsString
)
>
=
env
::
vars_os
()
.collect
();
// If huggingface_hub_cache is set, pass it to the download process
// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
// Useful when running inside a docker container
if
let
Some
(
ref
huggingface_hub_cache
)
=
args
.huggingface_hub_cache
{
if
let
Some
(
ref
huggingface_hub_cache
)
=
args
.huggingface_hub_cache
{
env
.push
((
"HUGGINGFACE_HUB_CACHE"
.into
(),
huggingface_hub_cache
.into
()));
env
s
.push
((
"HUGGINGFACE_HUB_CACHE"
.into
(),
huggingface_hub_cache
.into
()));
};
};
// Enable hf transfer for insane download speeds
// Enable hf transfer for insane download speeds
let
enable_hf_transfer
=
env
::
var
(
"HF_HUB_ENABLE_HF_TRANSFER"
)
.unwrap_or
(
"1"
.to_string
());
let
enable_hf_transfer
=
env
::
var
(
"HF_HUB_ENABLE_HF_TRANSFER"
)
.unwrap_or
(
"1"
.to_string
());
env
.push
((
env
s
.push
((
"HF_HUB_ENABLE_HF_TRANSFER"
.into
(),
"HF_HUB_ENABLE_HF_TRANSFER"
.into
(),
enable_hf_transfer
.into
(),
enable_hf_transfer
.into
(),
));
));
// Parse Inference API token
// Parse Inference API token
if
let
Ok
(
api_token
)
=
env
::
var
(
"HF_API_TOKEN"
)
{
if
let
Ok
(
api_token
)
=
env
::
var
(
"HF_API_TOKEN"
)
{
env
.push
((
"HUGGING_FACE_HUB_TOKEN"
.into
(),
api_token
.into
()))
env
s
.push
((
"HUGGING_FACE_HUB_TOKEN"
.into
(),
api_token
.into
()))
};
};
// If args.weights_cache_override is some, pass it to the download process
// If args.weights_cache_override is some, pass it to the download process
// Useful when running inside a HuggingFace Inference Endpoint
// Useful when running inside a HuggingFace Inference Endpoint
if
let
Some
(
weights_cache_override
)
=
&
args
.weights_cache_override
{
if
let
Some
(
weights_cache_override
)
=
&
args
.weights_cache_override
{
env
.push
((
env
s
.push
((
"WEIGHTS_CACHE_OVERRIDE"
.into
(),
"WEIGHTS_CACHE_OVERRIDE"
.into
(),
weights_cache_override
.into
(),
weights_cache_override
.into
(),
));
));
...
@@ -681,8 +681,8 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
...
@@ -681,8 +681,8 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Start process
// Start process
tracing
::
info!
(
"Starting download process."
);
tracing
::
info!
(
"Starting download process."
);
let
mut
download_process
=
match
Command
::
new
(
"text-generation-server"
)
let
mut
download_process
=
match
Command
::
new
(
"text-generation-server"
)
.args
(
download_arg
v
)
.args
(
download_arg
s
)
.envs
(
env
)
.envs
(
env
s
)
.stdout
(
Stdio
::
piped
())
.stdout
(
Stdio
::
piped
())
.stderr
(
Stdio
::
piped
())
.stderr
(
Stdio
::
piped
())
.process_group
(
0
)
.process_group
(
0
)
...
@@ -738,10 +738,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
...
@@ -738,10 +738,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
return
Err
(
LauncherError
::
DownloadError
);
return
Err
(
LauncherError
::
DownloadError
);
}
}
if
!
running
.load
(
Ordering
::
SeqCst
)
{
if
!
running
.load
(
Ordering
::
SeqCst
)
{
signal
::
kill
(
Pid
::
from_raw
(
download_process
.id
()
as
i32
),
Signal
::
SIGTERM
)
.unwrap
();
terminate
(
"download"
,
download_process
,
Duration
::
from_secs
(
10
))
.unwrap
();
tracing
::
info!
(
"Waiting for download process to gracefully shutdown"
);
download_process
.wait
()
.unwrap
();
tracing
::
info!
(
"Download process terminated"
);
return
Ok
(());
return
Ok
(());
}
}
sleep
(
Duration
::
from_millis
(
100
));
sleep
(
Duration
::
from_millis
(
100
));
...
@@ -844,7 +841,7 @@ fn spawn_webserver(
...
@@ -844,7 +841,7 @@ fn spawn_webserver(
// All shard started
// All shard started
// Start webserver
// Start webserver
tracing
::
info!
(
"Starting Webserver"
);
tracing
::
info!
(
"Starting Webserver"
);
let
mut
arg
v
=
vec!
[
let
mut
router_
arg
s
=
vec!
[
"--max-concurrent-requests"
.to_string
(),
"--max-concurrent-requests"
.to_string
(),
args
.max_concurrent_requests
.to_string
(),
args
.max_concurrent_requests
.to_string
(),
"--max-best-of"
.to_string
(),
"--max-best-of"
.to_string
(),
...
@@ -877,24 +874,24 @@ fn spawn_webserver(
...
@@ -877,24 +874,24 @@ fn spawn_webserver(
// Model optional revision
// Model optional revision
if
let
Some
(
ref
revision
)
=
args
.revision
{
if
let
Some
(
ref
revision
)
=
args
.revision
{
arg
v
.push
(
"--revision"
.to_string
());
router_
arg
s
.push
(
"--revision"
.to_string
());
arg
v
.push
(
revision
.to_string
())
router_
arg
s
.push
(
revision
.to_string
())
}
}
if
args
.json_output
{
if
args
.json_output
{
arg
v
.push
(
"--json-output"
.to_string
());
router_
arg
s
.push
(
"--json-output"
.to_string
());
}
}
// OpenTelemetry
// OpenTelemetry
if
let
Some
(
otlp_endpoint
)
=
args
.otlp_endpoint
{
if
let
Some
(
otlp_endpoint
)
=
args
.otlp_endpoint
{
arg
v
.push
(
"--otlp-endpoint"
.to_string
());
router_
arg
s
.push
(
"--otlp-endpoint"
.to_string
());
arg
v
.push
(
otlp_endpoint
);
router_
arg
s
.push
(
otlp_endpoint
);
}
}
// CORS origins
// CORS origins
for
origin
in
args
.cors_allow_origin
.into_iter
()
{
for
origin
in
args
.cors_allow_origin
.into_iter
()
{
arg
v
.push
(
"--cors-allow-origin"
.to_string
());
router_
arg
s
.push
(
"--cors-allow-origin"
.to_string
());
arg
v
.push
(
origin
);
router_
arg
s
.push
(
origin
);
}
}
// Ngrok
// Ngrok
...
@@ -904,34 +901,34 @@ fn spawn_webserver(
...
@@ -904,34 +901,34 @@ fn spawn_webserver(
LauncherError
::
WebserverCannotStart
LauncherError
::
WebserverCannotStart
})
?
;
})
?
;
arg
v
.push
(
"--ngrok"
.to_string
());
router_
arg
s
.push
(
"--ngrok"
.to_string
());
arg
v
.push
(
"--ngrok-authtoken"
.to_string
());
router_
arg
s
.push
(
"--ngrok-authtoken"
.to_string
());
arg
v
.push
(
authtoken
);
router_
arg
s
.push
(
authtoken
);
if
let
Some
(
domain
)
=
args
.ngrok_domain
{
if
let
Some
(
domain
)
=
args
.ngrok_domain
{
arg
v
.push
(
"--ngrok-domain"
.to_string
());
router_
arg
s
.push
(
"--ngrok-domain"
.to_string
());
arg
v
.push
(
domain
);
router_
arg
s
.push
(
domain
);
}
}
if
let
(
Some
(
username
),
Some
(
password
))
=
(
args
.ngrok_username
,
args
.ngrok_password
)
{
if
let
(
Some
(
username
),
Some
(
password
))
=
(
args
.ngrok_username
,
args
.ngrok_password
)
{
arg
v
.push
(
"--ngrok-username"
.to_string
());
router_
arg
s
.push
(
"--ngrok-username"
.to_string
());
arg
v
.push
(
username
);
router_
arg
s
.push
(
username
);
arg
v
.push
(
"--ngrok-password"
.to_string
());
router_
arg
s
.push
(
"--ngrok-password"
.to_string
());
arg
v
.push
(
password
);
router_
arg
s
.push
(
password
);
}
}
}
}
// Copy current process env
// Copy current process env
let
mut
env
:
Vec
<
(
OsString
,
OsString
)
>
=
env
::
vars_os
()
.collect
();
let
mut
env
s
:
Vec
<
(
OsString
,
OsString
)
>
=
env
::
vars_os
()
.collect
();
// Parse Inference API token
// Parse Inference API token
if
let
Ok
(
api_token
)
=
env
::
var
(
"HF_API_TOKEN"
)
{
if
let
Ok
(
api_token
)
=
env
::
var
(
"HF_API_TOKEN"
)
{
env
.push
((
"HUGGING_FACE_HUB_TOKEN"
.into
(),
api_token
.into
()))
env
s
.push
((
"HUGGING_FACE_HUB_TOKEN"
.into
(),
api_token
.into
()))
};
};
let
mut
webserver
=
match
Command
::
new
(
"text-generation-router"
)
let
mut
webserver
=
match
Command
::
new
(
"text-generation-router"
)
.args
(
arg
v
)
.args
(
router_
arg
s
)
.envs
(
env
)
.envs
(
env
s
)
.stdout
(
Stdio
::
piped
())
.stdout
(
Stdio
::
piped
())
.stderr
(
Stdio
::
piped
())
.stderr
(
Stdio
::
piped
())
.process_group
(
0
)
.process_group
(
0
)
...
@@ -969,6 +966,31 @@ fn spawn_webserver(
...
@@ -969,6 +966,31 @@ fn spawn_webserver(
Ok
(
webserver
)
Ok
(
webserver
)
}
}
/// Gracefully stops a child process, force-killing it if it outlives `timeout`.
///
/// Sends `SIGTERM` to `process`, then polls it every 100 ms until it exits or until
/// `timeout` has elapsed since termination started. If the process is still running
/// after that, it is killed and reaped.
///
/// `process_name` is only used in the log messages.
///
/// # Errors
///
/// Returns an error if `SIGTERM` cannot be sent, or if polling, killing or waiting
/// on the child fails.
fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
    tracing::info!("Terminating {process_name}");

    let terminate_time = Instant::now();
    // Surface a failed signal through the returned `io::Result` like every other
    // failure in this function, instead of panicking inside the helper.
    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM)
        .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;

    tracing::info!("Waiting for {process_name} to gracefully shutdown");

    // Poll rather than block on `wait()` so the timeout can be enforced.
    while terminate_time.elapsed() < timeout {
        if let Some(status) = process.try_wait()? {
            tracing::info!("{process_name} terminated");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }

    // The process ignored SIGTERM for the whole grace period: kill it and reap it.
    tracing::info!("Killing {process_name}");

    process.kill()?;
    let exit_status = process.wait()?;

    tracing::info!("{process_name} killed");
    Ok(exit_status)
}
fn
main
()
->
Result
<
(),
LauncherError
>
{
fn
main
()
->
Result
<
(),
LauncherError
>
{
// Pattern match configuration
// Pattern match configuration
let
args
=
Args
::
parse
();
let
args
=
Args
::
parse
();
...
@@ -1038,6 +1060,11 @@ fn main() -> Result<(), LauncherError> {
...
@@ -1038,6 +1060,11 @@ fn main() -> Result<(), LauncherError> {
// Download and convert model weights
// Download and convert model weights
download_convert_model
(
&
args
,
running
.clone
())
?
;
download_convert_model
(
&
args
,
running
.clone
())
?
;
if
!
running
.load
(
Ordering
::
SeqCst
)
{
// Launcher was asked to stop
return
Ok
(());
}
// Shared shutdown bool
// Shared shutdown bool
let
shutdown
=
Arc
::
new
(
AtomicBool
::
new
(
false
));
let
shutdown
=
Arc
::
new
(
AtomicBool
::
new
(
false
));
// Shared shutdown channel
// Shared shutdown channel
...
@@ -1096,10 +1123,7 @@ fn main() -> Result<(), LauncherError> {
...
@@ -1096,10 +1123,7 @@ fn main() -> Result<(), LauncherError> {
}
}
// Graceful termination
// Graceful termination
signal
::
kill
(
Pid
::
from_raw
(
webserver
.id
()
as
i32
),
Signal
::
SIGTERM
)
.unwrap
();
terminate
(
"webserver"
,
webserver
,
Duration
::
from_secs
(
90
))
.unwrap
();
tracing
::
info!
(
"Waiting for webserver to gracefully shutdown"
);
webserver
.wait
()
.unwrap
();
tracing
::
info!
(
"Webserver terminated"
);
shutdown_shards
(
shutdown
,
&
shutdown_receiver
);
shutdown_shards
(
shutdown
,
&
shutdown_receiver
);
exit_code
exit_code
...
...
server/pyproject.toml
View file @
c58a0c18
[tool.poetry]
[tool.poetry]
name
=
"text-generation-server"
name
=
"text-generation-server"
version
=
"0.9.
1
"
version
=
"0.9.
2
"
description
=
"Text Generation Inference Python gRPC Server"
description
=
"Text Generation Inference Python gRPC Server"
authors
=
[
"Olivier Dehaene <olivier@huggingface.co>"
]
authors
=
[
"Olivier Dehaene <olivier@huggingface.co>"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment