"vscode:/vscode.git/clone" did not exist on "19d64f2b725889cfbdb000937a2d57c07db5cfa8"
Commit 5a68ac62 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 312626203
parent 0ff8d37a
...@@ -32,6 +32,8 @@ import time ...@@ -32,6 +32,8 @@ import time
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
from six.moves import range
from six.moves import zip
import tensorflow as tf import tensorflow as tf
from tensorflow.python import eager as tfe from tensorflow.python import eager as tfe
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
...@@ -75,7 +77,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None): ...@@ -75,7 +77,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
compute_accuracy(logits, labels)) compute_accuracy(logits, labels))
grads = tape.gradient(loss_value, model.variables) grads = tape.gradient(loss_value, model.variables)
optimizer.apply_gradients( optimizer.apply_gradients(
zip(grads, model.variables), global_step=step_counter) list(zip(grads, model.variables)), global_step=step_counter)
if log_interval and batch % log_interval == 0: if log_interval and batch % log_interval == 0:
rate = log_interval / (time.time() - start) rate = log_interval / (time.time() - start)
print('Step #%d\tLoss: %.6f (%d steps/sec)' % (batch, loss_value, rate)) print('Step #%d\tLoss: %.6f (%d steps/sec)' % (batch, loss_value, rate))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment