"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "3f8d00df7834cf26cdd8562879d8f83d72a06348"
Unverified Commit 5fd2dcb5 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat(launcher): default num_shard to CUDA_VISIBLE_DEVICES if possible (#108)

parent 0ac38d33
...@@ -115,13 +115,11 @@ fn main() -> ExitCode { ...@@ -115,13 +115,11 @@ fn main() -> ExitCode {
None => { None => {
// try to default to the number of available GPUs // try to default to the number of available GPUs
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES"); tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES") let n_devices = num_cuda_devices()
.expect("--num-shard and CUDA_VISIBLE_DEVICES are not set"); .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
let n_devices = cuda_visible_devices.split(",").count();
if n_devices <= 1 { if n_devices <= 1 {
panic!("`sharded` is true but only found {n_devices} CUDA devices"); panic!("`sharded` is true but only found {n_devices} CUDA devices");
} }
tracing::info!("Sharding on {n_devices} found CUDA devices");
n_devices n_devices
} }
Some(num_shard) => { Some(num_shard) => {
...@@ -144,9 +142,19 @@ fn main() -> ExitCode { ...@@ -144,9 +142,19 @@ fn main() -> ExitCode {
} }
} }
} else { } else {
// default to a single shard match num_shard {
num_shard.unwrap_or(1) // get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard
None => num_cuda_devices().unwrap_or(1),
Some(num_shard) => num_shard,
}
}; };
if num_shard < 1 {
panic!("`num_shard` cannot be < 1");
}
if num_shard > 1 {
tracing::info!("Sharding model on {num_shard} processes");
}
// Signal handler // Signal handler
let running = Arc::new(AtomicBool::new(true)); let running = Arc::new(AtomicBool::new(true));
...@@ -669,3 +677,11 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive ...@@ -669,3 +677,11 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
// This will block till all shutdown_sender are dropped // This will block till all shutdown_sender are dropped
let _ = shutdown_receiver.recv(); let _ = shutdown_receiver.recv();
} }
fn num_cuda_devices() -> Option<usize> {
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
let n_devices = cuda_visible_devices.split(',').count();
return Some(n_devices);
}
None
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment