Unverified Commit 9b049266 authored by Ayushman Kumar's avatar Ayushman Kumar Committed by GitHub
Browse files

Merge pull request #3 from tensorflow/master

Updated
parents 63af6ba5 c5ad244e
......@@ -20,6 +20,7 @@ from __future__ import print_function
import tensorflow.compat.v2 as tf
from official.modeling import performance
from official.staging.training import standard_runnable
from official.staging.training import utils
from official.utils.flags import core as flags_core
......@@ -85,21 +86,15 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
# Make sure iterations variable is created inside scope.
self.global_step = self.optimizer.iterations
if self.dtype == tf.float16:
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
self.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
self.optimizer, loss_scale))
elif flags_obj.fp16_implementation == 'graph_rewrite':
# `dtype` is still float32 in this case. We built the graph in float32
# and let the graph rewrite change parts of it float16.
if not flags_obj.use_tf_function:
raise ValueError('--fp16_implementation=graph_rewrite requires '
'--use_tf_function to be true')
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
self.optimizer = (
tf.train.experimental.enable_mixed_precision_graph_rewrite(
self.optimizer, loss_scale))
use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
if use_graph_rewrite and not flags_obj.use_tf_function:
raise ValueError('--fp16_implementation=graph_rewrite requires '
'--use_tf_function to be true')
self.optimizer = performance.configure_optimizer(
self.optimizer,
use_float16=self.dtype == tf.float16,
use_graph_rewrite=use_graph_rewrite,
loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
......
......@@ -5,6 +5,10 @@ This folder contains machine learning models implemented by researchers in
respective authors. To propose a model for inclusion, please submit a pull
request.
**Note: some research models are stale and have not updated to the latest
TensorFlow yet. If users have trouble with TF 2.x for research models,
please consider TF 1.15.**
## Models
- [adversarial_crypto](adversarial_crypto): protecting communications with
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment