OpenDAS / text-generation-inference / Commits / 77758f60

Unverified commit 77758f60, authored Apr 26, 2023 by Nicolas Patry, committed by GitHub on Apr 26, 2023

chore(launcher): refactor logic (#242)

Hopefully it's cleaner
parent 7de8a377

Showing 1 changed file with 550 additions and 543 deletions:
launcher/src/main.rs (+550, -543)
launcher/src/main.rs @ 77758f60
@@ -4,7 +4,6 @@ use std::env;
 use std::ffi::OsString;
 use std::io::{BufRead, BufReader, Read};
 use std::path::Path;
-use std::process::ExitCode;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::mpsc::TryRecvError;
 use std::sync::Arc;
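The dropped import in this first hunk points at the central mechanical change of the refactor: `main` no longer returns `std::process::ExitCode` but `Result<(), LauncherError>` (see the last hunk), so the helper functions can propagate failures with `?`. Below is a minimal, self-contained sketch of that pattern only; the names mirror the diff, but the bodies and the `ok: bool` parameters are stand-ins for illustration, not launcher code.

// Illustration of the error-propagation pattern this commit adopts.
#[derive(Debug)]
enum LauncherError {
    DownloadError,
    WebserverCannotStart,
}

fn download_model(ok: bool) -> Result<(), LauncherError> {
    // A failing step returns Err instead of an explicit ExitCode::FAILURE.
    if ok { Ok(()) } else { Err(LauncherError::DownloadError) }
}

fn spawn_webserver(ok: bool) -> Result<(), LauncherError> {
    if ok { Ok(()) } else { Err(LauncherError::WebserverCannotStart) }
}

// Because main returns Result, a propagated Err is Debug-printed and the
// process exits with a non-zero status, so no ExitCode plumbing is needed.
fn main() -> Result<(), LauncherError> {
    download_model(true)?;
    spawn_webserver(true)?;
    Ok(())
}

Running such a binary and checking the shell's exit status after a failure gives the same non-zero code that the old launcher produced by returning `ExitCode::FAILURE` explicitly.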
@@ -73,248 +72,454 @@ struct Args {

Removed: the opening of the former monolithic fn main() -> ExitCode (CLI parsing, tracing setup, the destructuring of Args, the inline num_shard resolution, the Ctrl-C handler, the local-model check, the inline weight-download logic, and the start of the shard-spawn loop), now extracted into the helpers below.

Added (after the unchanged tail of Args):

    watermark_delta: Option<f32>,
}

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed((usize, String)),
}

#[allow(clippy::too_many_arguments)]
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    otlp_endpoint: Option<String>,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<Mutex<bool>>,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Get UDS path
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    fs::remove_file(uds).unwrap_or_default();

    // Process args
    let mut shard_argv = vec![
        "text-generation-server".to_string(),
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Activate tensor parallelism
    if world_size > 1 {
        shard_argv.push("--sharded".to_string());
    }

    if quantize {
        shard_argv.push("--quantize".to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_argv.push("--revision".to_string());
        shard_argv.push(revision)
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_argv.push("--otlp-endpoint".to_string());
        shard_argv.push(otlp_endpoint);
    }

    // Copy current process env
    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Torch Distributed Env vars
    env.push(("RANK".into(), rank.to_string().into()));
    env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    env.push(("MASTER_ADDR".into(), master_addr.into()));
    env.push(("MASTER_PORT".into(), master_port.to_string().into()));
    env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));

    // Safetensors load fast
    env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    env.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        env.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard {rank}");
    let mut p = match Popen::create(
        &shard_argv,
        PopenConfig {
            stdout: Redirection::Pipe,
            stderr: Redirection::Pipe,
            // Needed for the shutdown procedure
            setpgid: true,
            // NCCL env vars
            env: Some(env),
            ..Default::default()
        },
    ) {
        Ok(p) => p,
        Err(err) => {
            if let PopenError::IoError(ref err) = err {
                if err.kind() == io::ErrorKind::NotFound {
                    tracing::error!("text-generation-server not found in PATH");
                    tracing::error!("Please install it with `make install-server`")
                }
            }
            status_sender
                .send(ShardStatus::Failed((rank, err.to_string())))
                .unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let shard_stdout = p.stdout.take().unwrap();

    thread::spawn(move || {
        // Enter shard-manager tracing span
        let stdout = BufReader::new(shard_stdout);
        let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
        for line in stdout.lines() {
            // Parse loguru logs
            if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
                log.trace();
            }
        }
    });

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if p.poll().is_some() {
            let mut err = String::new();
            p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
            status_sender
                .send(ShardStatus::Failed((rank, err)))
                .unwrap();
            return;
        }

        // We received a shutdown signal
        if *shutdown.lock().unwrap() {
            p.terminate().unwrap();
            let _ = p.wait_timeout(Duration::from_secs(90));
            tracing::info!("Shard {rank} terminated");
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard {rank} to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}

fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Update shutdown value to true
    // This will be picked up by the shard manager
    {
        let mut shutdown = shutdown.lock().unwrap();
        *shutdown = true;
    }

    // Wait for shards to shutdown
    // This will block till all shutdown_sender are dropped
    let _ = shutdown_receiver.recv();
}

fn num_cuda_devices() -> Option<usize> {
    if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
        let n_devices = cuda_visible_devices.split(',').count();
        return Some(n_devices);
    }
    None
}

#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

impl PythonLogMessage {
    fn trace(&self) {
        match self.record.level.name {
            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
            PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
            PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
        }
    }
}

fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
            let n_devices =
                num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
            if n_devices <= 1 {
                panic!("`sharded` is true but only found {n_devices} CUDA devices");
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                panic!("`sharded` is true but `num_shard` <= 1");
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        panic!("`num_shard` cannot be < 1");
    }
    num_shard
}

#[derive(Debug)]
enum LauncherError {
    DownloadError,
    ShardCannotStart,
    ShardDisconnected,
    ShardFailed,
    WebserverFailed,
    WebserverCannotStart,
}

fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
    let mut download_argv = vec![
        "text-generation-server".to_string(),
        "download-weights".to_string(),
        args.model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Model optional revision
    if let Some(revision) = &args.revision {
        download_argv.push("--revision".to_string());
        download_argv.push(revision.to_string())
    }

    // Copy current process env
    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

    // If huggingface_hub_cache is set, pass it to the shard
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    env.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // Start process
    tracing::info!("Starting download process.");
    let mut download_process = match Popen::create(
        &download_argv,
        PopenConfig {
            stdout: Redirection::Pipe,
            stderr: Redirection::Pipe,
            // Needed for the shutdown procedure
            setpgid: true,
            env: Some(env),
            ..Default::default()
        },
    ) {
        Ok(p) => p,
        Err(err) => {
            if let PopenError::IoError(ref err) = err {
                if err.kind() == io::ErrorKind::NotFound {
                    tracing::error!("text-generation-server not found in PATH");
                    tracing::error!("Please install it with `make install-server`")
                }
            }
            return Err(LauncherError::DownloadError);
        }
    };

    // Redirect STDOUT to the console
    let download_stdout = download_process.stdout.take().unwrap();
    thread::spawn(move || {
        // Enter download tracing span
        let stdout = BufReader::new(download_stdout);
        let _span = tracing::span!(tracing::Level::INFO, "download").entered();
        for line in stdout.lines() {
            // Parse loguru logs
            if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
                log.trace();
            }
        }
    });

    loop {
        if let Some(status) = download_process.poll() {
            match status {
                ExitStatus::Exited(exit_code) => {
                    if exit_code == 0 {
                        tracing::info!("Successfully downloaded weights.");
                        break;
                    } else {
                        let mut err = String::new();
                        download_process
                            .stderr
                            .take()
                            .unwrap()
                            .read_to_string(&mut err)
                            .unwrap();
                        tracing::error!("Download encountered an error: {err}");
                        return Err(LauncherError::DownloadError);
                    }
                }
                _ => {
                    tracing::error!("Download process exited with an unknown status.");
                    return Err(LauncherError::DownloadError);
                }
            }
        }
        if !running.load(Ordering::SeqCst) {
            download_process.terminate().unwrap();
            tracing::info!("Waiting for download process to gracefully shutdown");
            download_process
                .wait_timeout(Duration::from_secs(90))
                .unwrap();
            tracing::info!("Download process terminated");
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}

fn spawn_shards(
    num_shard: usize,
    args: &Args,
    shutdown: Arc<Mutex<bool>>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let quantize = args.quantize.clone();
        let master_port = args.master_port.clone();
        let disable_custom_kernels = args.disable_custom_kernels.clone();
        let watermark_gamma = args.watermark_gamma.clone();
        let watermark_delta = args.watermark_delta.clone();
        thread::spawn(move || {
            shard_manager(
                model_id,
...
@@ -355,422 +560,224 @@ fn main() -> ExitCode {
Removed: the remainder of the former fn main() -> ExitCode (webserver startup, the shard/webserver monitoring loop, and the ExitCode plumbing) together with the old end-of-file copies of ShardStatus, shard_manager, shutdown_shards, num_cuda_devices, and the PythonLog* helpers, which now appear earlier in the file.

Added (continuing inside spawn_shards, then spawn_webserver and the new main):

            Ok(ShardStatus::Failed((rank, err))) => {
                tracing::error!("Shard {} failed to start:\n{}", rank, err);
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}

fn spawn_webserver(
    args: Args,
    shutdown: Arc<Mutex<bool>>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Popen, LauncherError> {
    // All shard started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut argv = vec![
        "text-generation-router".to_string(),
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-input-length".to_string(),
        args.max_input_length.to_string(),
        "--max-total-tokens".to_string(),
        args.max_total_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
    ];

    // Deprecate max_batch_size
    if let Some(max_batch_size) = args.max_batch_size {
        argv.push("--max-batch-size".to_string());
        argv.push(max_batch_size.to_string())
    } else {
        argv.push("--max-batch-total-tokens".to_string());
        argv.push(args.max_batch_total_tokens.to_string())
    }

    // Model optional revision
    if let Some(ref revision) = args.revision {
        argv.push("--revision".to_string());
        argv.push(revision.to_string())
    }

    if args.json_output {
        argv.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        argv.push("--otlp-endpoint".to_string());
        argv.push(otlp_endpoint);
    }

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        argv.push("--cors-allow-origin".to_string());
        argv.push(origin);
    }

    // Copy current process env
    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    let mut webserver = match Popen::create(
        &argv,
        PopenConfig {
            stdout: Redirection::Pipe,
            stderr: Redirection::Pipe,
            // Needed for the shutdown procedure
            setpgid: true,
            env: Some(env),
            ..Default::default()
        },
    ) {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if let PopenError::IoError(err) = err {
                if err.kind() == io::ErrorKind::NotFound {
                    tracing::error!("text-generation-router not found in PATH");
                    tracing::error!("Please install it with `make install-router`")
                }
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, &shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}

fn main() -> Result<(), LauncherError> {
    // Pattern match configuration
    let args = Args::parse();

    if args.json_output {
        tracing_subscriber::fmt().json().init();
    } else {
        tracing_subscriber::fmt().compact().init();
    }

    tracing::info!("{:?}", args);

    let num_shard = find_num_shards(args.sharded, args.num_shard);
    if num_shard > 1 {
        tracing::info!("Sharding model on {num_shard} processes");
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Check if model_id is a local model
    let local_path = Path::new(&args.model_id);
    let is_local_model = local_path.exists() && local_path.is_dir();

    // Download weights for sharded models
    if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
        download_model(&args, running.clone())?;
    }

    // Shared shutdown bool
    let shutdown = Arc::new(Mutex::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();
    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    spawn_shards(
        num_shard,
        &args,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

    let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} failed:\n{err}");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.poll() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    webserver.terminate().unwrap();
    tracing::info!("Waiting for webserver to gracefully shutdown");
    webserver.wait_timeout(Duration::from_secs(90)).unwrap();
    tracing::info!("Webserver terminated");
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}
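One piece of plumbing the diff moves around without changing is the loguru bridge: the shard and download processes print JSON log records on stdout, and the launcher deserializes each line into `PythonLogMessage` and re-emits it through `tracing`. A small self-contained illustration of that round trip follows, reusing the struct definitions from the file above; the sample JSON line and the `serde`/`serde_json`/`tracing`/`tracing_subscriber` crate setup are assumptions of this sketch, not part of the commit.

use serde::Deserialize;

#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

fn main() {
    tracing_subscriber::fmt().compact().init();
    // A line in the shape loguru's JSON sink produces (hypothetical sample).
    let line = r#"{"text": "Server started", "record": {"level": {"name": "INFO"}}}"#;
    if let Ok(log) = serde_json::from_str::<PythonLogMessage>(line) {
        // The launcher's PythonLogMessage::trace does this dispatch for every level;
        // only the INFO case is spelled out here to keep the sketch short.
        match log.record.level.name {
            PythonLogLevelEnum::Info => tracing::info!("{}", log.text),
            _ => tracing::debug!("{}", log.text),
        }
    }
}

In the launcher itself this dispatch lives in `PythonLogMessage::trace`, which the stdout-forwarding threads in `shard_manager` and `download_model` call for every JSON line they manage to parse.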