Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
5fd2dcb5
Unverified
Commit
5fd2dcb5
authored
Mar 08, 2023
by
OlivierDehaene
Committed by
GitHub
Mar 08, 2023
Browse files
feat(launcher): default num_shard to CUDA_VISIBLE_DEVICES if possible (#108)
parent
0ac38d33
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
5 deletions
+21
-5
launcher/src/main.rs
launcher/src/main.rs
+21
-5
No files found.
launcher/src/main.rs
View file @
5fd2dcb5
...
@@ -115,13 +115,11 @@ fn main() -> ExitCode {
...
@@ -115,13 +115,11 @@ fn main() -> ExitCode {
None
=>
{
None
=>
{
// try to default to the number of available GPUs
// try to default to the number of available GPUs
tracing
::
info!
(
"Parsing num_shard from CUDA_VISIBLE_DEVICES"
);
tracing
::
info!
(
"Parsing num_shard from CUDA_VISIBLE_DEVICES"
);
let
cuda_visible_devices
=
env
::
var
(
"CUDA_VISIBLE_DEVICES"
)
let
n_devices
=
num_cuda_devices
(
)
.expect
(
"--num-shard and CUDA_VISIBLE_DEVICES are not set"
);
.expect
(
"--num-shard and CUDA_VISIBLE_DEVICES are not set"
);
let
n_devices
=
cuda_visible_devices
.split
(
","
)
.count
();
if
n_devices
<=
1
{
if
n_devices
<=
1
{
panic!
(
"`sharded` is true but only found {n_devices} CUDA devices"
);
panic!
(
"`sharded` is true but only found {n_devices} CUDA devices"
);
}
}
tracing
::
info!
(
"Sharding on {n_devices} found CUDA devices"
);
n_devices
n_devices
}
}
Some
(
num_shard
)
=>
{
Some
(
num_shard
)
=>
{
...
@@ -144,9 +142,19 @@ fn main() -> ExitCode {
...
@@ -144,9 +142,19 @@ fn main() -> ExitCode {
}
}
}
}
}
else
{
}
else
{
// default to a single shard
match
num_shard
{
num_shard
.unwrap_or
(
1
)
// get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard
None
=>
num_cuda_devices
()
.unwrap_or
(
1
),
Some
(
num_shard
)
=>
num_shard
,
}
};
};
if
num_shard
<
1
{
panic!
(
"`num_shard` cannot be < 1"
);
}
if
num_shard
>
1
{
tracing
::
info!
(
"Sharding model on {num_shard} processes"
);
}
// Signal handler
// Signal handler
let
running
=
Arc
::
new
(
AtomicBool
::
new
(
true
));
let
running
=
Arc
::
new
(
AtomicBool
::
new
(
true
));
...
@@ -669,3 +677,11 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
...
@@ -669,3 +677,11 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
// This will block till all shutdown_sender are dropped
// This will block till all shutdown_sender are dropped
let
_
=
shutdown_receiver
.recv
();
let
_
=
shutdown_receiver
.recv
();
}
}
/// Returns the number of devices listed in the `CUDA_VISIBLE_DEVICES` environment variable.
///
/// The value is read as a comma-separated list. Entries that are empty after trimming
/// whitespace are ignored, so a trailing or doubled comma (`"0,1,"`) does not inflate the
/// count.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or lists no device at
/// all (for example when it is set to the empty string), so callers can apply the same
/// fallback they use when the variable is missing.
fn num_cuda_devices() -> Option<usize> {
    let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES").ok()?;

    // `"".split(',')` still yields one (empty) item, so blank entries must be filtered out
    // explicitly or an empty variable would be reported as one device.
    let n_devices = cuda_visible_devices
        .split(',')
        .filter(|device| !device.trim().is_empty())
        .count();

    if n_devices == 0 {
        None
    } else {
        Some(n_devices)
    }
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment