Commit 0a7720e9 authored by Benjamin Fattori

multi-device: take minimum computed autobatch over all ranks

parent f4df3e48
@@ -354,8 +354,19 @@ class HFLM(LM):
            return batch_size

        batch_size = forward_batch()
        utils.clear_torch_cache()

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            max_rnk_bs = torch.tensor([batch_size], device=self.device)
            gathered = (
                self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(gathered)
            utils.clear_torch_cache()
            return batch_size

        utils.clear_torch_cache()
        return batch_size
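The hunk above synchronizes the automatically detected batch size across data-parallel ranks: each rank probes its own maximum, then all ranks adopt the global minimum so collective forward passes stay in lockstep. Below is a minimal standalone sketch of the same gather-and-min pattern, assuming a job started with `accelerate launch`; the `probe_batch_size` helper and the variable names are illustrative stand-ins for the harness's `forward_batch` detection, not part of this commit.

```python
# Minimal sketch of the gather-and-min pattern, assuming an initialized
# `accelerate` process group (e.g. started via `accelerate launch script.py`).
# `probe_batch_size` is a hypothetical stand-in for per-rank autodetection.
import torch
from accelerate import Accelerator

accelerator = Accelerator()

def probe_batch_size() -> int:
    # Pretend each rank detected a different maximum batch size, as would
    # happen on GPUs with uneven amounts of free memory.
    return 64 // (accelerator.process_index + 1)

batch_size = probe_batch_size()

if accelerator.num_processes > 1:
    # gather() returns a tensor with one entry per rank; taking the minimum
    # guarantees every rank can actually fit the chosen batch size.
    local = torch.tensor([batch_size], device=accelerator.device)
    batch_size = min(accelerator.gather(local).cpu().tolist())

print(f"rank {accelerator.process_index}: batch_size = {batch_size}")
```

Without this synchronization, ranks with less free memory would pick smaller batch sizes than their peers, and mismatched batch shapes across ranks can stall or break collective operations.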