Unverified Commit fa28535d authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Add stdev to the Dense layer. (#7189)

parent 13feba3c
......@@ -26,6 +26,7 @@ from __future__ import print_function
import functools
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import regularizers
......@@ -241,9 +242,11 @@ def resnet(num_blocks, classes=10, training=None):
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
x = layers.Dense(classes, activation='softmax',
# kernel_initializer='he_normal',
x = layers.Dense(classes,
activation='softmax',
kernel_initializer=initializers.RandomNormal(stddev=0.01),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='fc10')(x)
inputs = img_input
......
......@@ -27,14 +27,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import warnings
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import utils
L2_WEIGHT_DECAY = 1e-4
......@@ -45,15 +42,14 @@ BATCH_NORM_EPSILON = 1e-5
def identity_block(input_tensor, kernel_size, filters, stage, block):
"""The identity block is the block that has no conv layer at shortcut.
# Arguments
Args:
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
kernel_size: default 3, the kernel size of middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
# Returns
Returns:
Output tensor for the block.
"""
filters1, filters2, filters3 = filters
......@@ -107,21 +103,20 @@ def conv_block(input_tensor,
strides=(2, 2)):
"""A block that has a conv layer at shortcut.
# Arguments
Note that from stage 3,
the second conv layer at main path is with strides=(2, 2)
And the shortcut should have strides=(2, 2) as well
Args:
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
kernel_size: default 3, the kernel size of middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
strides: Strides for the second conv layer in the block.
# Returns
Returns:
Output tensor for the block.
Note that from stage 3,
the second conv layer at main path is with strides=(2, 2)
And the shortcut should have strides=(2, 2) as well
"""
filters1, filters2, filters3 = filters
if backend.image_data_format() == 'channels_last':
......@@ -175,11 +170,12 @@ def conv_block(input_tensor,
def resnet50(num_classes, dtype='float32', batch_size=None):
# TODO(tfboyd): add training argument, just lik resnet56.
"""Instantiates the ResNet50 architecture.
Args:
num_classes: `int` number of classes for image classification.
dtype: dtype to use float32 or float16 are most common.
batch_size: Size of the batches for each step.
Returns:
A Keras model instance.
......@@ -234,9 +230,11 @@ def resnet50(num_classes, dtype='float32', batch_size=None):
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
x = layers.Dense(
num_classes,
kernel_initializer=initializers.RandomNormal(stddev=0.01),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='fc1000')(x)
# TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
# single line of code.
x = backend.cast(x, 'float32')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment