"driver/include/conv_common.hpp" did not exist on "8d4607403e42ae0289efc98a392c6964501bc9cf"
Unverified Commit c77fa461 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge branch 'multigpu-feature' into multigpu-feature-minor-edits

parents bc103ce2 650d3c76
......@@ -72,7 +72,6 @@ class HFLM(LM):
# multigpu support with accelerate
if gpus > 1:
# accelerator = Accelerator(device_placement=False)
accelerator = Accelerator()
if gpus > accelerator.num_processes:
warning = (
......@@ -82,9 +81,9 @@ class HFLM(LM):
f"Current run will proceed with {accelerator.num_processes} devices."
)
print(warning)
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
else:
self.gpt2 = accelerator.prepare(self.gpt2)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment