Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c868da8b
Commit
c868da8b
authored
Dec 25, 2018
by
Shining Sun
Browse files
bug fixes
parent
25efe03e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
9 additions
and
49 deletions
+9
-49
official/resnet/keras/keras_cifar_main.py
official/resnet/keras/keras_cifar_main.py
+1
-2
official/resnet/keras/keras_common.py
official/resnet/keras/keras_common.py
+1
-1
official/resnet/keras/resnet_cifar_model.py
official/resnet/keras/resnet_cifar_model.py
+6
-46
official/resnet/keras/resnet_model.py
official/resnet/keras/resnet_model.py
+1
-0
No files found.
official/resnet/keras/keras_cifar_main.py
View file @
c868da8b
...
@@ -136,8 +136,7 @@ def run(flags_obj):
...
@@ -136,8 +136,7 @@ def run(flags_obj):
strategy
=
distribution_utils
.
get_distribution_strategy
(
strategy
=
distribution_utils
.
get_distribution_strategy
(
flags_obj
.
num_gpus
,
flags_obj
.
turn_off_distribution_strategy
)
flags_obj
.
num_gpus
,
flags_obj
.
turn_off_distribution_strategy
)
model
=
resnet_cifar_model
.
resnet56
(
input_shape
=
(
32
,
32
,
3
),
model
=
resnet_cifar_model
.
resnet56
(
classes
=
cifar_main
.
NUM_CLASSES
)
classes
=
cifar_main
.
NUM_CLASSES
)
model
.
compile
(
loss
=
'categorical_crossentropy'
,
model
.
compile
(
loss
=
'categorical_crossentropy'
,
optimizer
=
optimizer
,
optimizer
=
optimizer
,
...
...
official/resnet/keras/keras_common.py
View file @
c868da8b
...
@@ -163,7 +163,7 @@ def define_keras_flags():
...
@@ -163,7 +163,7 @@ def define_keras_flags():
flags
.
DEFINE_integer
(
flags
.
DEFINE_integer
(
name
=
'train_steps'
,
default
=
None
,
name
=
'train_steps'
,
default
=
None
,
help
=
'The number of steps to run for training. If it is larger than '
help
=
'The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # bathes per epoch. When this flag is '
'# batches per epoch, then use # bat
c
hes per epoch. When this flag is '
'set, only one epoch is going to run for training.'
)
'set, only one epoch is going to run for training.'
)
...
...
official/resnet/keras/resnet_cifar_model.py
View file @
c868da8b
...
@@ -33,42 +33,6 @@ BATCH_NORM_EPSILON = 1e-5
...
@@ -33,42 +33,6 @@ BATCH_NORM_EPSILON = 1e-5
L2_WEIGHT_DECAY
=
2e-4
L2_WEIGHT_DECAY
=
2e-4
def
_obtain_input_shape
(
input_shape
,
default_size
,
data_format
):
"""Internal utility to compute/validate a model's input shape.
Arguments:
input_shape: Either None (will return the default network input shape),
or a user-provided shape to be validated.
default_size: Default input width/height for the model.
data_format: Image data format to use.
Returns:
An integer shape tuple (may include None entries).
Raises:
ValueError: In case of invalid argument values.
"""
if
input_shape
and
len
(
input_shape
)
==
3
:
if
data_format
==
'channels_first'
:
if
input_shape
[
0
]
not
in
{
1
,
3
}:
warnings
.
warn
(
'This model usually expects 1 or 3 input channels. '
'However, it was passed an input_shape with '
+
str
(
input_shape
[
0
])
+
' input channels.'
)
default_shape
=
(
input_shape
[
0
],
default_size
,
default_size
)
else
:
if
input_shape
[
-
1
]
not
in
{
1
,
3
}:
warnings
.
warn
(
'This model usually expects 1 or 3 input channels. '
'However, it was passed an input_shape with '
+
str
(
input_shape
[
-
1
])
+
' input channels.'
)
default_shape
=
(
default_size
,
default_size
,
input_shape
[
-
1
])
return
input_shape
def
identity_building_block
(
input_tensor
,
def
identity_building_block
(
input_tensor
,
kernel_size
,
kernel_size
,
filters
,
filters
,
...
@@ -212,7 +176,7 @@ def conv_building_block(input_tensor,
...
@@ -212,7 +176,7 @@ def conv_building_block(input_tensor,
return
x
return
x
def
resnet56
(
input_shape
=
None
,
classes
=
100
,
training
=
None
):
def
resnet56
(
classes
=
100
,
training
=
None
):
"""Instantiates the ResNet56 architecture.
"""Instantiates the ResNet56 architecture.
Arguments:
Arguments:
...
@@ -225,16 +189,12 @@ def resnet56(input_shape=None, classes=100, training=None):
...
@@ -225,16 +189,12 @@ def resnet56(input_shape=None, classes=100, training=None):
A Keras model instance.
A Keras model instance.
"""
"""
# Determine proper input shape
# Determine proper input shape
input_shape
=
_obtain_input_shape
(
if
backend
.
image_data_format
()
==
'channels_first'
:
input_shape
,
input_shape
=
(
3
,
32
,
32
)
default_size
=
32
,
data_format
=
tf
.
keras
.
backend
.
image_data_format
())
img_input
=
tf
.
keras
.
layers
.
Input
(
shape
=
input_shape
)
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
bn_axis
=
3
else
:
bn_axis
=
1
bn_axis
=
1
else
:
# channel_last
input_shape
=
(
32
,
32
,
3
)
bn_axis
=
3
x
=
tf
.
keras
.
layers
.
ZeroPadding2D
(
padding
=
(
1
,
1
),
name
=
'conv1_pad'
)(
img_input
)
x
=
tf
.
keras
.
layers
.
ZeroPadding2D
(
padding
=
(
1
,
1
),
name
=
'conv1_pad'
)(
img_input
)
x
=
tf
.
keras
.
layers
.
Conv2D
(
16
,
(
3
,
3
),
x
=
tf
.
keras
.
layers
.
Conv2D
(
16
,
(
3
,
3
),
...
...
official/resnet/keras/resnet_model.py
View file @
c868da8b
...
@@ -181,6 +181,7 @@ def conv_block(input_tensor,
...
@@ -181,6 +181,7 @@ def conv_block(input_tensor,
def
resnet50
(
num_classes
):
def
resnet50
(
num_classes
):
# TODO(tfboyd): add training argument, just lik resnet56.
"""Instantiates the ResNet50 architecture.
"""Instantiates the ResNet50 architecture.
Args:
Args:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment