Commit 85a6db17 authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 368573031
parent d0879611
...@@ -28,18 +28,14 @@ from __future__ import division ...@@ -28,18 +28,14 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
from official.vision.image_classification.resnet import imagenet_preprocessing from official.vision.image_classification.resnet import imagenet_preprocessing
layers = tf.keras.layers layers = tf.keras.layers
def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4): def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
return regularizers.l2(l2_weight_decay) if use_l2_regularizer else None return tf.keras.regularizers.L2(
l2_weight_decay) if use_l2_regularizer else None
def identity_block(input_tensor, def identity_block(input_tensor,
...@@ -66,7 +62,7 @@ def identity_block(input_tensor, ...@@ -66,7 +62,7 @@ def identity_block(input_tensor,
Output tensor for the block. Output tensor for the block.
""" """
filters1, filters2, filters3 = filters filters1, filters2, filters3 = filters
if backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3 bn_axis = 3
else: else:
bn_axis = 1 bn_axis = 1
...@@ -154,7 +150,7 @@ def conv_block(input_tensor, ...@@ -154,7 +150,7 @@ def conv_block(input_tensor,
Output tensor for the block. Output tensor for the block.
""" """
filters1, filters2, filters3 = filters filters1, filters2, filters3 = filters
if backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3 bn_axis = 3
else: else:
bn_axis = 1 bn_axis = 1
...@@ -253,7 +249,7 @@ def resnet50(num_classes, ...@@ -253,7 +249,7 @@ def resnet50(num_classes,
# Hub image modules expect inputs in the range [0, 1]. This rescales these # Hub image modules expect inputs in the range [0, 1]. This rescales these
# inputs to the range expected by the trained model. # inputs to the range expected by the trained model.
x = layers.Lambda( x = layers.Lambda(
lambda x: x * 255.0 - backend.constant( lambda x: x * 255.0 - tf.keras.backend.constant( # pylint: disable=g-long-lambda
imagenet_preprocessing.CHANNEL_MEANS, imagenet_preprocessing.CHANNEL_MEANS,
shape=[1, 1, 3], shape=[1, 1, 3],
dtype=x.dtype), dtype=x.dtype),
...@@ -262,7 +258,7 @@ def resnet50(num_classes, ...@@ -262,7 +258,7 @@ def resnet50(num_classes,
else: else:
x = img_input x = img_input
if backend.image_data_format() == 'channels_first': if tf.keras.backend.image_data_format() == 'channels_first':
x = layers.Permute((3, 1, 2))(x) x = layers.Permute((3, 1, 2))(x)
bn_axis = 1 bn_axis = 1
else: # channels_last else: # channels_last
...@@ -315,7 +311,8 @@ def resnet50(num_classes, ...@@ -315,7 +311,8 @@ def resnet50(num_classes,
x = layers.GlobalAveragePooling2D()(x) x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense( x = layers.Dense(
num_classes, num_classes,
kernel_initializer=initializers.RandomNormal(stddev=0.01), kernel_initializer=tf.compat.v1.keras.initializers.random_normal(
stddev=0.01),
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='fc1000')( name='fc1000')(
...@@ -326,4 +323,4 @@ def resnet50(num_classes, ...@@ -326,4 +323,4 @@ def resnet50(num_classes,
x = layers.Activation('softmax', dtype='float32')(x) x = layers.Activation('softmax', dtype='float32')(x)
# Create model. # Create model.
return models.Model(img_input, x, name='resnet50') return tf.keras.Model(img_input, x, name='resnet50')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment