Commit e4f1dfb6 authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

improved warning message for processes < total gpus

parent 80fe11d8
...@@ -72,13 +72,13 @@ class HFLM(LM): ...@@ -72,13 +72,13 @@ class HFLM(LM):
if gpus > 1: if gpus > 1:
accelerator = Accelerator(device_placement=False) accelerator = Accelerator(device_placement=False)
if gpus > accelerator.num_processes: if gpus > accelerator.num_processes:
warning = ("WARNING: The number of total GPUs does not match the number of spawned processes. " warning = ("WARNING: The number of total system GPUs does not match the number of spawned processes. "
"If you would like to use data parallelism, please launch the script " "If you would like to use data parallelism, please launch the script "
"with 'accelerate launch *script*'. " "with 'accelerate launch *script*'. "
"Current run will proceed with single device.") f"Current run will proceed with {accelerator.num_processes} devices.")
print(warning) print(warning)
self._rank = 0 self._rank = self.accelerator.local_process_index
self._world_size = 1 self._world_size = self.accelerator.num_processes
else: else:
self.gpt2 = accelerator.prepare(self.gpt2) self.gpt2 = accelerator.prepare(self.gpt2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment