Commit 85a6db17 authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 368573031
parent d0879611
......@@ -28,18 +28,14 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
from official.vision.image_classification.resnet import imagenet_preprocessing
layers = tf.keras.layers
def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
return regularizers.l2(l2_weight_decay) if use_l2_regularizer else None
return tf.keras.regularizers.L2(
l2_weight_decay) if use_l2_regularizer else None
def identity_block(input_tensor,
......@@ -66,7 +62,7 @@ def identity_block(input_tensor,
Output tensor for the block.
"""
filters1, filters2, filters3 = filters
if backend.image_data_format() == 'channels_last':
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
......@@ -154,7 +150,7 @@ def conv_block(input_tensor,
Output tensor for the block.
"""
filters1, filters2, filters3 = filters
if backend.image_data_format() == 'channels_last':
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
......@@ -253,7 +249,7 @@ def resnet50(num_classes,
# Hub image modules expect inputs in the range [0, 1]. This rescales these
# inputs to the range expected by the trained model.
x = layers.Lambda(
lambda x: x * 255.0 - backend.constant(
lambda x: x * 255.0 - tf.keras.backend.constant( # pylint: disable=g-long-lambda
imagenet_preprocessing.CHANNEL_MEANS,
shape=[1, 1, 3],
dtype=x.dtype),
......@@ -262,7 +258,7 @@ def resnet50(num_classes,
else:
x = img_input
if backend.image_data_format() == 'channels_first':
if tf.keras.backend.image_data_format() == 'channels_first':
x = layers.Permute((3, 1, 2))(x)
bn_axis = 1
else: # channels_last
......@@ -315,7 +311,8 @@ def resnet50(num_classes,
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(
num_classes,
kernel_initializer=initializers.RandomNormal(stddev=0.01),
kernel_initializer=tf.compat.v1.keras.initializers.random_normal(
stddev=0.01),
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='fc1000')(
......@@ -326,4 +323,4 @@ def resnet50(num_classes,
x = layers.Activation('softmax', dtype='float32')(x)
# Create model.
return models.Model(img_input, x, name='resnet50')
return tf.keras.Model(img_input, x, name='resnet50')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment