Commit 4a30a5f6 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix inconsistent gradient check

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/692

Differential Revision: D15174954

fbshipit-source-id: 1a7bff9aeed3e2cc658577be9d79e8c9f72314c2
parent ffc9c8cc
@@ -11,6 +11,7 @@ Train a network across multiple GPUs.
 from collections import OrderedDict
 from itertools import chain
+import math
 import os
 import torch
@@ -239,8 +240,10 @@ class Trainer(object):
         logging_outputs = list(chain.from_iterable(logging_outputs))
         sample_sizes = list(chain.from_iterable(sample_sizes))
         ooms = sum(ooms)
-        assert all(norm == prev_norms[0] for norm in prev_norms), \
-            'Fatal error: gradients are inconsistent between workers'
+        assert (
+            all(norm == prev_norms[0] for norm in prev_norms)
+            or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
+        ), 'Fatal error: gradients are inconsistent between workers'
         self.meters['oom'].update(ooms, len(samples))
         if ooms == self.args.distributed_world_size * len(samples):
 ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment