"web/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "e4b2ccfb230e8f2d964d8d08533a348c7e839278"
Commit 8ff61153 authored by Naurril's avatar Naurril Committed by Taylor Robie
Browse files

remove final_size parameter of resnet (#5326)

parent 630c4ca8
...@@ -184,7 +184,6 @@ class Cifar10Model(resnet_model.Model): ...@@ -184,7 +184,6 @@ class Cifar10Model(resnet_model.Model):
first_pool_stride=None, first_pool_stride=None,
block_sizes=[num_blocks] * 3, block_sizes=[num_blocks] * 3,
block_strides=[1, 2, 2], block_strides=[1, 2, 2],
final_size=64,
resnet_version=resnet_version, resnet_version=resnet_version,
data_format=data_format, data_format=data_format,
dtype=dtype dtype=dtype
......
...@@ -232,10 +232,8 @@ class ImagenetModel(resnet_model.Model): ...@@ -232,10 +232,8 @@ class ImagenetModel(resnet_model.Model):
# For bigger models, we want to use "bottleneck" layers # For bigger models, we want to use "bottleneck" layers
if resnet_size < 50: if resnet_size < 50:
bottleneck = False bottleneck = False
final_size = 512
else: else:
bottleneck = True bottleneck = True
final_size = 2048
super(ImagenetModel, self).__init__( super(ImagenetModel, self).__init__(
resnet_size=resnet_size, resnet_size=resnet_size,
...@@ -248,7 +246,6 @@ class ImagenetModel(resnet_model.Model): ...@@ -248,7 +246,6 @@ class ImagenetModel(resnet_model.Model):
first_pool_stride=2, first_pool_stride=2,
block_sizes=_get_block_sizes(resnet_size), block_sizes=_get_block_sizes(resnet_size),
block_strides=[1, 2, 2, 2], block_strides=[1, 2, 2, 2],
final_size=final_size,
resnet_version=resnet_version, resnet_version=resnet_version,
data_format=data_format, data_format=data_format,
dtype=dtype dtype=dtype
......
...@@ -354,7 +354,7 @@ class Model(object): ...@@ -354,7 +354,7 @@ class Model(object):
kernel_size, kernel_size,
conv_stride, first_pool_size, first_pool_stride, conv_stride, first_pool_size, first_pool_stride,
block_sizes, block_strides, block_sizes, block_strides,
final_size, resnet_version=DEFAULT_VERSION, data_format=None, resnet_version=DEFAULT_VERSION, data_format=None,
dtype=DEFAULT_DTYPE): dtype=DEFAULT_DTYPE):
"""Creates a model for classifying an image. """Creates a model for classifying an image.
...@@ -376,7 +376,6 @@ class Model(object): ...@@ -376,7 +376,6 @@ class Model(object):
i-th set. i-th set.
block_strides: List of integers representing the desired stride size for block_strides: List of integers representing the desired stride size for
each of the sets of block layers. Should be same length as block_sizes. each of the sets of block layers. Should be same length as block_sizes.
final_size: The expected size of the model after the second pooling.
resnet_version: Integer representing which version of the ResNet network resnet_version: Integer representing which version of the ResNet network
to use. See README for details. Valid values: [1, 2] to use. See README for details. Valid values: [1, 2]
data_format: Input format ('channels_last', 'channels_first', or None). data_format: Input format ('channels_last', 'channels_first', or None).
...@@ -422,7 +421,6 @@ class Model(object): ...@@ -422,7 +421,6 @@ class Model(object):
self.first_pool_stride = first_pool_stride self.first_pool_stride = first_pool_stride
self.block_sizes = block_sizes self.block_sizes = block_sizes
self.block_strides = block_strides self.block_strides = block_strides
self.final_size = final_size
self.dtype = dtype self.dtype = dtype
self.pre_activation = resnet_version == 2 self.pre_activation = resnet_version == 2
...@@ -542,7 +540,7 @@ class Model(object): ...@@ -542,7 +540,7 @@ class Model(object):
inputs = tf.reduce_mean(inputs, axes, keepdims=True) inputs = tf.reduce_mean(inputs, axes, keepdims=True)
inputs = tf.identity(inputs, 'final_reduce_mean') inputs = tf.identity(inputs, 'final_reduce_mean')
inputs = tf.reshape(inputs, [-1, self.final_size]) inputs = tf.squeeze(inputs, axes)
inputs = tf.layers.dense(inputs=inputs, units=self.num_classes) inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
inputs = tf.identity(inputs, 'final_dense') inputs = tf.identity(inputs, 'final_dense')
return inputs return inputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment