"...resnet50_tensorflow.git" did not exist on "88d844e7b15bb3b7140f101bc59e8c2a3b5cdf19"
Unverified Commit 9b049266 authored by Ayushman Kumar's avatar Ayushman Kumar Committed by GitHub
Browse files

Merge pull request #3 from tensorflow/master

Updated
parents 63af6ba5 c5ad244e
@@ -20,6 +20,7 @@ from __future__ import print_function
import tensorflow.compat.v2 as tf
from official.modeling import performance
from official.staging.training import standard_runnable
from official.staging.training import utils
from official.utils.flags import core as flags_core
@@ -85,21 +86,15 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
    # Make sure iterations variable is created inside scope.
    self.global_step = self.optimizer.iterations
-    if self.dtype == tf.float16:
-      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
-      self.optimizer = (
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-              self.optimizer, loss_scale))
-    elif flags_obj.fp16_implementation == 'graph_rewrite':
-      # `dtype` is still float32 in this case. We built the graph in float32
-      # and let the graph rewrite change parts of it float16.
-      if not flags_obj.use_tf_function:
-        raise ValueError('--fp16_implementation=graph_rewrite requires '
-                         '--use_tf_function to be true')
-      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
-      self.optimizer = (
-          tf.train.experimental.enable_mixed_precision_graph_rewrite(
-              self.optimizer, loss_scale))
+    use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
+    if use_graph_rewrite and not flags_obj.use_tf_function:
+      raise ValueError('--fp16_implementation=graph_rewrite requires '
+                       '--use_tf_function to be true')
+    self.optimizer = performance.configure_optimizer(
+        self.optimizer,
+        use_float16=self.dtype == tf.float16,
+        use_graph_rewrite=use_graph_rewrite,
+        loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
    self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
......
@@ -5,6 +5,10 @@ This folder contains machine learning models implemented by researchers in
respective authors. To propose a model for inclusion, please submit a pull
request.
**Note: some research models are stale and have not updated to the latest
TensorFlow yet. If users have trouble with TF 2.x for research models,
please consider TF 1.15.**
## Models
- [adversarial_crypto](adversarial_crypto): protecting communications with
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment