Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
565c3fa3
Unverified
Commit
565c3fa3
authored
Feb 08, 2018
by
Neal Wu
Committed by
GitHub
Feb 08, 2018
Browse files
Merge pull request #3343 from tensorflow/resnet-num-classes
Allow users to pass in num_classes to ResNet
parents
7cb653fd
75c04257
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
9 deletions
+50
-9
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+10
-2
official/resnet/cifar10_test.py
official/resnet/cifar10_test.py
+19
-5
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+10
-2
official/resnet/imagenet_test.py
official/resnet/imagenet_test.py
+11
-0
No files found.
official/resnet/cifar10_main.py
View file @
565c3fa3
...
...
@@ -129,8 +129,16 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Running the model
###############################################################################
class
Cifar10Model
(
resnet
.
Model
):
def
__init__
(
self
,
resnet_size
,
data_format
=
None
):
def
__init__
(
self
,
resnet_size
,
data_format
=
None
,
num_classes
=
_NUM_CLASSES
):
"""These are the parameters that work for CIFAR-10 data.
Args:
resnet_size: The number of convolutional layers needed in the model.
data_format: Either 'channels_first' or 'channels_last', specifying which
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
"""
if
resnet_size
%
6
!=
2
:
raise
ValueError
(
'resnet_size must be 6n + 2:'
,
resnet_size
)
...
...
@@ -139,7 +147,7 @@ class Cifar10Model(resnet.Model):
super
(
Cifar10Model
,
self
).
__init__
(
resnet_size
=
resnet_size
,
num_classes
=
_NUM_CLASSES
,
num_classes
=
num_classes
,
num_filters
=
16
,
kernel_size
=
3
,
conv_stride
=
1
,
...
...
official/resnet/cifar10_test.py
View file @
565c3fa3
...
...
@@ -27,6 +27,9 @@ import cifar10_main
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
ERROR
)
_BATCH_SIZE
=
128
_HEIGHT
=
32
_WIDTH
=
32
_NUM_CHANNELS
=
3
class
BaseTest
(
tf
.
test
.
TestCase
):
...
...
@@ -34,8 +37,8 @@ class BaseTest(tf.test.TestCase):
def
test_dataset_input_fn
(
self
):
fake_data
=
bytearray
()
fake_data
.
append
(
7
)
for
i
in
range
(
3
):
for
_
in
range
(
1024
):
for
i
in
range
(
_NUM_CHANNELS
):
for
_
in
range
(
_HEIGHT
*
_WIDTH
):
fake_data
.
append
(
i
)
_
,
filename
=
mkstemp
(
dir
=
self
.
get_temp_dir
())
...
...
@@ -49,8 +52,8 @@ class BaseTest(tf.test.TestCase):
lambda
val
:
cifar10_main
.
parse_record
(
val
,
False
))
image
,
label
=
fake_dataset
.
make_one_shot_iterator
().
get_next
()
self
.
assertEqual
(
label
.
get_
shape
().
as_list
()
,
[
10
]
)
self
.
assertEqual
(
image
.
get_
shape
().
as_list
(),
[
32
,
32
,
3
]
)
self
.
assert
All
Equal
(
label
.
shape
,
(
10
,)
)
self
.
assert
All
Equal
(
image
.
shape
,
(
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
)
)
with
self
.
test_session
()
as
sess
:
image
,
label
=
sess
.
run
([
image
,
label
])
...
...
@@ -62,7 +65,7 @@ class BaseTest(tf.test.TestCase):
self
.
assertAllClose
(
pixel
,
np
.
array
([
-
1.225
,
0.
,
1.225
]),
rtol
=
1e-3
)
def
input_fn
(
self
):
features
=
tf
.
random_uniform
([
_BATCH_SIZE
,
32
,
32
,
3
])
features
=
tf
.
random_uniform
([
_BATCH_SIZE
,
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
])
labels
=
tf
.
random_uniform
(
[
_BATCH_SIZE
],
maxval
=
9
,
dtype
=
tf
.
int32
)
return
features
,
tf
.
one_hot
(
labels
,
10
)
...
...
@@ -104,6 +107,17 @@ class BaseTest(tf.test.TestCase):
def
test_cifar10_model_fn_predict_mode
(
self
):
self
.
cifar10_model_fn_helper
(
tf
.
estimator
.
ModeKeys
.
PREDICT
)
def
test_cifar10model_shape
(
self
):
batch_size
=
135
num_classes
=
246
model
=
cifar10_main
.
Cifar10Model
(
32
,
data_format
=
'channels_last'
,
num_classes
=
num_classes
)
fake_input
=
tf
.
random_uniform
([
batch_size
,
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
])
output
=
model
(
fake_input
,
training
=
True
)
self
.
assertAllEqual
(
output
.
shape
,
(
batch_size
,
num_classes
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/resnet/imagenet_main.py
View file @
565c3fa3
...
...
@@ -132,8 +132,16 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Running the model
###############################################################################
class
ImagenetModel
(
resnet
.
Model
):
def
__init__
(
self
,
resnet_size
,
data_format
=
None
):
def
__init__
(
self
,
resnet_size
,
data_format
=
None
,
num_classes
=
_NUM_CLASSES
):
"""These are the parameters that work for Imagenet data.
Args:
resnet_size: The number of convolutional layers needed in the model.
data_format: Either 'channels_first' or 'channels_last', specifying which
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
"""
# For bigger models, we want to use "bottleneck" layers
...
...
@@ -146,7 +154,7 @@ class ImagenetModel(resnet.Model):
super
(
ImagenetModel
,
self
).
__init__
(
resnet_size
=
resnet_size
,
num_classes
=
_NUM_CLASSES
,
num_classes
=
num_classes
,
num_filters
=
64
,
kernel_size
=
7
,
conv_stride
=
2
,
...
...
official/resnet/imagenet_test.py
View file @
565c3fa3
...
...
@@ -176,6 +176,17 @@ class BaseTest(tf.test.TestCase):
def
test_resnet_model_fn_predict_mode
(
self
):
self
.
resnet_model_fn_helper
(
tf
.
estimator
.
ModeKeys
.
PREDICT
)
def
test_imagenetmodel_shape
(
self
):
batch_size
=
135
num_classes
=
246
model
=
imagenet_main
.
ImagenetModel
(
50
,
data_format
=
'channels_last'
,
num_classes
=
num_classes
)
fake_input
=
tf
.
random_uniform
([
batch_size
,
224
,
224
,
3
])
output
=
model
(
fake_input
,
training
=
True
)
self
.
assertAllEqual
(
output
.
shape
,
(
batch_size
,
num_classes
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment