Commit 1de7e4a5 authored by Benjamin Fattori's avatar Benjamin Fattori Committed by lintangsutawika
Browse files

accelerator multidevice setup

parent 2af4f9e0
......@@ -71,7 +71,27 @@ class Seq2SeqHFLM(LM):
self.batch_size_per_gpu = batch_size
if gpus > 1:
raise NotImplementedError
accelerator = Accelerator()
if gpus > accelerator.num_processes:
warning = (
"WARNING: The number of total system GPUs does not match the number of spawned processes. "
"If you would like to use data parallelism, please launch the script "
"with 'accelerate launch *script*'. "
f"Current run will proceed with {accelerator.num_processes} devices."
)
print(warning)
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
else:
self.model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
print(f"Using {gpus} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
@property
def eot_token_id(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment