Commit 14efeaa5, authored by Yeqing Li; committed by A. Unique TensorFlower
Browse files

Check NaN during training loop.

PiperOrigin-RevId: 279116219
parent 98db9b25
...@@ -22,6 +22,7 @@ from __future__ import print_function ...@@ -22,6 +22,7 @@ from __future__ import print_function
import json import json
import os import os
import numpy as np
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -512,6 +513,8 @@ class DistributedExecutor(object): ...@@ -512,6 +513,8 @@ class DistributedExecutor(object):
train_loss) train_loss)
if not isinstance(train_loss, dict): if not isinstance(train_loss, dict):
train_loss = {'total_loss': train_loss} train_loss = {'total_loss': train_loss}
if np.isnan(train_loss['total_loss']):
raise ValueError('total loss is NaN.')
if train_metric: if train_metric:
train_metric_result = train_metric.result() train_metric_result = train_metric.result()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment