Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4ec3452d
Commit
4ec3452d
authored
Jan 26, 2017
by
Yaroslav Bulatov
Browse files
another xrange change + change to concat_v2
parent
10340bf5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
11 deletions
+3
-11
resnet/cifar_input.py
resnet/cifar_input.py
+1
-10
resnet/resnet_main.py
resnet/resnet_main.py
+2
-1
No files found.
resnet/cifar_input.py
View file @
4ec3452d
...
...
@@ -18,15 +18,6 @@
import
tensorflow
as
tf
# backward compatible concat (arg order changed in head)
import
inspect
def
concat
(
values
,
axis
):
if
'axis'
in
inspect
.
signature
(
tf
.
concat
).
parameters
.
keys
():
return
tf
.
concat
(
values
=
values
,
axis
=
axis
)
else
:
assert
'concat_dim'
in
inspect
.
signature
(
tf
.
concat
).
parameters
.
keys
()
return
tf
.
concat
(
concat_dim
=
axis
,
values
=
values
)
def
build_input
(
dataset
,
data_path
,
batch_size
,
mode
):
"""Build CIFAR image and labels.
...
...
@@ -109,7 +100,7 @@ def build_input(dataset, data_path, batch_size, mode):
labels
=
tf
.
reshape
(
labels
,
[
batch_size
,
1
])
indices
=
tf
.
reshape
(
tf
.
range
(
0
,
batch_size
,
1
),
[
batch_size
,
1
])
labels
=
tf
.
sparse_to_dense
(
tf
.
concat
(
values
=
[
indices
,
labels
],
axis
=
1
),
tf
.
concat
_v2
(
values
=
[
indices
,
labels
],
axis
=
1
),
[
batch_size
,
num_classes
],
1.0
,
0.0
)
assert
len
(
images
.
get_shape
())
==
4
...
...
resnet/resnet_main.py
View file @
4ec3452d
...
...
@@ -16,6 +16,7 @@
"""ResNet Train/Eval module.
"""
import
time
import
six
import
sys
import
cifar_input
...
...
@@ -140,7 +141,7 @@ def evaluate(hps):
saver
.
restore
(
sess
,
ckpt_state
.
model_checkpoint_path
)
total_prediction
,
correct_prediction
=
0
,
0
for
_
in
x
range
(
FLAGS
.
eval_batch_count
):
for
_
in
six
.
moves
.
range
(
FLAGS
.
eval_batch_count
):
(
summaries
,
loss
,
predictions
,
truth
,
train_step
)
=
sess
.
run
(
[
model
.
summaries
,
model
.
cost
,
model
.
predictions
,
model
.
labels
,
model
.
global_step
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment