Commit 14efeaa5 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Check NaN during training loop.

PiperOrigin-RevId: 279116219
parent 98db9b25
......@@ -22,6 +22,7 @@ from __future__ import print_function
import json
import os
import numpy as np
from absl import flags
from absl import logging
import tensorflow as tf
......@@ -512,6 +513,8 @@ class DistributedExecutor(object):
train_loss)
if not isinstance(train_loss, dict):
train_loss = {'total_loss': train_loss}
if np.isnan(train_loss['total_loss']):
raise ValueError('total loss is NaN.')
if train_metric:
train_metric_result = train_metric.result()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment