Commit a442244d authored by Halil Akin, committed by Facebook Github Bot

Protect against failures in case of OOMs

Summary: Fixes some distributed training failures that occur when OOMs are observed.

Reviewed By: myleott

Differential Revision: D13121054

fbshipit-source-id: f71a0a695332acbaa1797e89887b8b7c7ddaa727
parent 693894b6
@@ -75,14 +75,15 @@ class FairseqOptimizer(object):
     def multiply_grads(self, c):
         """Multiplies grads by a constant ``c``."""
         for p in self.params:
-            p.grad.data.mul_(c)
+            if p.grad is not None:
+                p.grad.data.mul_(c)
 
     def clip_grad_norm(self, max_norm):
         """Clips gradient norm."""
         if max_norm > 0:
             return torch.nn.utils.clip_grad_norm_(self.params, max_norm)
         else:
             return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params))
+            return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params if p.grad is not None))
 
     def step(self, closure=None):
         """Performs a single optimization step."""
...
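For context, a minimal sketch of the failure mode these guards address (a hypothetical two-parameter example, not code from this commit): when an OOM causes a backward pass to be skipped or to complete only partially, some parameters never receive a gradient, so p.grad stays None and the unguarded p.grad.data calls above raise an AttributeError.

import math
import torch

# Hypothetical setup: two parameters, but only one contributes to the loss,
# so after backward() the other parameter's .grad is still None -- the same
# state an OOM'd backward pass can leave the optimizer in.
w_used = torch.nn.Parameter(torch.ones(3))
w_unused = torch.nn.Parameter(torch.ones(3))
params = [w_used, w_unused]

(w_used * 2).sum().backward()

# Old behaviour: p.grad.data.mul_(c) raises AttributeError for w_unused.
# Guarded behaviour, as in this commit: skip parameters without gradients.
for p in params:
    if p.grad is not None:
        p.grad.data.mul_(0.5)

# The same guard applies to the manual norm used when max_norm <= 0.
grad_norm = math.sqrt(sum(p.grad.data.norm() ** 2 for p in params if p.grad is not None))
print(grad_norm)  # norm over only the parameters that actually have gradients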
@@ -210,7 +210,7 @@ class Trainer(object):
         sample_sizes = list(chain.from_iterable(sample_sizes))
         ooms = sum(ooms)
 
-        if ooms == self.args.distributed_world_size:
+        if ooms == self.args.distributed_world_size * len(samples):
             print('| WARNING: OOM in all workers, skipping update')
             self.zero_grad()
             return None
...
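A worked check of the changed condition (the numbers below are made up): ooms sums the OOM counts gathered from all workers, and the new comparison suggests each worker can report one OOM per sub-batch in samples, so the "everyone OOMed on everything" case totals distributed_world_size * len(samples) rather than just distributed_world_size.

# Hypothetical numbers illustrating the updated "OOM in all workers" check.
distributed_world_size = 4                 # number of workers
samples = ["sub-batch-0", "sub-batch-1"]   # len(samples) == 2 sub-batches per update

# Suppose every worker OOMs on every sub-batch (one OOM reported per sub-batch).
ooms_per_worker = [[1, 1] for _ in range(distributed_world_size)]
ooms = sum(count for per_worker in ooms_per_worker for count in per_worker)  # -> 8

assert ooms != distributed_world_size                  # old check (== 4) would not trigger
assert ooms == distributed_world_size * len(samples)   # new check (== 8) does
print('| WARNING: OOM in all workers, skipping update')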