Unverified Commit 9b049266 authored by Ayushman Kumar's avatar Ayushman Kumar Committed by GitHub
Browse files

Merge pull request #3 from tensorflow/master

Updated
parents 63af6ba5 c5ad244e
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
import tensorflow.compat.v2 as tf import tensorflow.compat.v2 as tf
from official.modeling import performance
from official.staging.training import standard_runnable from official.staging.training import standard_runnable
from official.staging.training import utils from official.staging.training import utils
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
...@@ -85,21 +86,15 @@ class ResnetRunnable(standard_runnable.StandardTrainable, ...@@ -85,21 +86,15 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
# Make sure iterations variable is created inside scope. # Make sure iterations variable is created inside scope.
self.global_step = self.optimizer.iterations self.global_step = self.optimizer.iterations
if self.dtype == tf.float16: use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128) if use_graph_rewrite and not flags_obj.use_tf_function:
self.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
self.optimizer, loss_scale))
elif flags_obj.fp16_implementation == 'graph_rewrite':
# `dtype` is still float32 in this case. We built the graph in float32
# and let the graph rewrite change parts of it float16.
if not flags_obj.use_tf_function:
raise ValueError('--fp16_implementation=graph_rewrite requires ' raise ValueError('--fp16_implementation=graph_rewrite requires '
'--use_tf_function to be true') '--use_tf_function to be true')
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128) self.optimizer = performance.configure_optimizer(
self.optimizer = ( self.optimizer,
tf.train.experimental.enable_mixed_precision_graph_rewrite( use_float16=self.dtype == tf.float16,
self.optimizer, loss_scale)) use_graph_rewrite=use_graph_rewrite,
loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32) self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
......
...@@ -5,6 +5,10 @@ This folder contains machine learning models implemented by researchers in ...@@ -5,6 +5,10 @@ This folder contains machine learning models implemented by researchers in
respective authors. To propose a model for inclusion, please submit a pull respective authors. To propose a model for inclusion, please submit a pull
request. request.
**Note: some research models are stale and have not updated to the latest
TensorFlow yet. If users have trouble with TF 2.x for research models,
please consider TF 1.15.**
## Models ## Models
- [adversarial_crypto](adversarial_crypto): protecting communications with - [adversarial_crypto](adversarial_crypto): protecting communications with
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment