Commit 46f37ec0 authored by Benjamin Fattori

Warn if launched incorrectly on a multi-GPU machine

parent c4447dab
@@ -68,18 +68,29 @@ class HFLM(LM):
         # multithreading and batching
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
 
         # multigpu support with accelerate
         if gpus > 1:
             accelerator = Accelerator(device_placement=False)
-            self.gpt2 = accelerator.prepare(self.gpt2)
-            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
-            self.accelerator = accelerator
-
-            if self.accelerator.is_local_main_process:
-                print(f"Using {gpus} GPUs with Data Parallelism")
-            self._rank = self.accelerator.local_process_index
-            self._world_size = self.accelerator.num_processes
-            assert gpus == self.accelerator.num_processes, "Number of GPUs does not match the world size. If evaluating with data parallelism, please call script with accelerate launch *script name*"
+            if gpus > accelerator.num_processes:
+                warning = (
+                    "WARNING: The number of total GPUs does not match the number of spawned processes. "
+                    "If you would like to use data parallelism, please launch the script "
+                    "with 'accelerate launch *script*'. "
+                    "Current run will proceed with single device."
+                )
+                print(warning)
+                self._rank = 0
+                self._world_size = 1
+            else:
+                self.gpt2 = accelerator.prepare(self.gpt2)
+                self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+                self.accelerator = accelerator
+
+                if self.accelerator.is_local_main_process:
+                    print(f"Using {gpus} devices with data parallelism")
+
+                self._rank = self.accelerator.local_process_index
+                self._world_size = self.accelerator.num_processes
 
     @property
     def eot_token_id(self):
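
For context, here is a minimal standalone sketch of the launch-mode guard this commit introduces, assuming a machine where torch.cuda.device_count() exceeds 1. The file name sketch.py is hypothetical and the script is not part of the repository; Accelerator.num_processes, local_process_index, and is_local_main_process are the same accelerate APIs used in the diff.

# sketch.py -- hypothetical standalone script; a minimal sketch of the guard.
#
# `python sketch.py` on a multi-GPU box: accelerate spawns no workers, so
# num_processes == 1 < gpus and the warning branch is taken.
# `accelerate launch sketch.py`: one process per configured device, so
# num_processes == gpus and the data-parallel branch is taken.
import torch
from accelerate import Accelerator

gpus = torch.cuda.device_count()
rank, world_size = 0, 1  # single-device defaults

if gpus > 1:
    accelerator = Accelerator(device_placement=False)
    if gpus > accelerator.num_processes:
        # Launched with plain `python`: warn and stay on one device.
        print("WARNING: total GPUs != spawned processes; "
              "use 'accelerate launch sketch.py' for data parallelism.")
    else:
        # One process per GPU: each worker reports its own rank.
        if accelerator.is_local_main_process:
            print(f"Using {gpus} devices with data parallelism")
        rank = accelerator.local_process_index
        world_size = accelerator.num_processes

print(f"rank={rank} world_size={world_size}")

The practical effect of the commit matches the sketch: launching with plain python on a multi-GPU node now prints the warning and proceeds on a single device, where the removed assert used to abort the run, while accelerate launch (one spawned process per device) is still required for data parallelism.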