Unverified Commit 1e2ceffd authored by Ayushman Kumar's avatar Ayushman Kumar Committed by GitHub
Browse files

Merge pull request #4 from tensorflow/master

Updating 
parents 51e60bab c7adbbe4
......@@ -31,7 +31,6 @@ import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers as tf_python_keras_layers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
from official.vision.image_classification import imagenet_preprocessing
......@@ -40,30 +39,7 @@ L2_WEIGHT_DECAY = 1e-4
BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5
layers = tf_python_keras_layers
def change_keras_layer(use_tf_keras_layers=False):
"""Change layers to either tf.keras.layers or tf.python.keras.layers.
Layer version of tf.keras.layers is depends on tensorflow version, but
tf.python.keras.layers checks environment variable TF2_BEHAVIOR.
This function is a temporal function to use tf.keras.layers.
Currently, tf v2 batchnorm layer is slower than tf v1 batchnorm layer.
this function is useful for tracking benchmark result for each version.
This function will be removed when we use tf.keras.layers as default.
TODO(b/146939027): Remove this function when tf v2 batchnorm reaches training
speed parity with tf v1 batchnorm.
Args:
use_tf_keras_layers: whether to use tf.keras.layers.
"""
global layers
if use_tf_keras_layers:
layers = tf.keras.layers
else:
layers = tf_python_keras_layers
layers = tf.keras.layers
def _gen_l2_regularizer(use_l2_regularizer=True):
......
......@@ -24,10 +24,10 @@ import os
from absl import app
from absl import flags
import tensorflow as tf
import tensorflow.compat.v2 as tf
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
from official.vision.image_classification.resnet import resnet_model
FLAGS = flags.FLAGS
......
......@@ -36,7 +36,7 @@ from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
from official.vision.image_classification.resnet import resnet_model
def run(flags_obj):
......@@ -126,7 +126,6 @@ def run(flags_obj):
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
use_keras_image_data_format=use_keras_image_data_format),
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
......@@ -142,7 +141,6 @@ def run(flags_obj):
is_training=False,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
use_keras_image_data_format=use_keras_image_data_format),
dtype=dtype,
......@@ -185,7 +183,6 @@ def run(flags_obj):
model = trivial_model.trivial_model(
imagenet_preprocessing.NUM_CLASSES)
elif flags_obj.model == 'resnet50_v1.5':
resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES)
elif flags_obj.model == 'mobilenet':
......
......@@ -50,7 +50,6 @@ class KerasImagenetTest(tf.test.TestCase):
"-model", "resnet50_v1.5",
"-optimizer", "resnet50_default",
"-pruning_method", "polynomial_decay",
"-use_tf_keras_layers", "true",
],
"mobilenet": [
"-model", "mobilenet",
......
......@@ -27,7 +27,7 @@ from official.staging.training import utils
from official.utils.flags import core as flags_core
from official.vision.image_classification import common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
from official.vision.image_classification.resnet import resnet_model
class ResnetRunnable(standard_runnable.StandardTrainable,
......@@ -70,7 +70,6 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
else:
self.input_fn = imagenet_preprocessing.input_fn
resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
self.model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES,
batch_size=flags_obj.batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment