Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
26c96542
"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "62dd95870c812d87418e53229eb3fdee95c8a067"
Commit
26c96542
authored
Mar 29, 2018
by
DefineFC
Browse files
prepare data, train, and eval on py3
parent
65f4d60b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
17 additions
and
10 deletions
+17
-10
research/deeplab/datasets/build_cityscapes_data.py
research/deeplab/datasets/build_cityscapes_data.py
+2
-2
research/deeplab/datasets/build_data.py
research/deeplab/datasets/build_data.py
+4
-1
research/deeplab/datasets/build_voc2012_data.py
research/deeplab/datasets/build_voc2012_data.py
+2
-2
research/deeplab/eval.py
research/deeplab/eval.py
+3
-2
research/deeplab/train.py
research/deeplab/train.py
+3
-2
research/deeplab/utils/train_utils.py
research/deeplab/utils/train_utils.py
+3
-1
No files found.
research/deeplab/datasets/build_cityscapes_data.py
View file @
26c96542
...
@@ -154,10 +154,10 @@ def _convert_dataset(dataset_split):
...
@@ -154,10 +154,10 @@ def _convert_dataset(dataset_split):
i
+
1
,
num_images
,
shard_id
))
i
+
1
,
num_images
,
shard_id
))
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
# Read the image.
# Read the image.
image_data
=
tf
.
gfile
.
FastGFile
(
image_files
[
i
],
'r'
).
read
()
image_data
=
tf
.
gfile
.
FastGFile
(
image_files
[
i
],
'r
b
'
).
read
()
height
,
width
=
image_reader
.
read_image_dims
(
image_data
)
height
,
width
=
image_reader
.
read_image_dims
(
image_data
)
# Read the semantic segmentation annotation.
# Read the semantic segmentation annotation.
seg_data
=
tf
.
gfile
.
FastGFile
(
label_files
[
i
],
'r'
).
read
()
seg_data
=
tf
.
gfile
.
FastGFile
(
label_files
[
i
],
'r
b
'
).
read
()
seg_height
,
seg_width
=
label_reader
.
read_image_dims
(
seg_data
)
seg_height
,
seg_width
=
label_reader
.
read_image_dims
(
seg_data
)
if
height
!=
seg_height
or
width
!=
seg_width
:
if
height
!=
seg_height
or
width
!=
seg_width
:
raise
RuntimeError
(
'Shape mismatched between image and label.'
)
raise
RuntimeError
(
'Shape mismatched between image and label.'
)
...
...
research/deeplab/datasets/build_data.py
View file @
26c96542
...
@@ -125,7 +125,10 @@ def _bytes_list_feature(values):
...
@@ -125,7 +125,10 @@ def _bytes_list_feature(values):
Returns:
Returns:
A TF-Feature.
A TF-Feature.
"""
"""
return
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
values
]))
def
norm2bytes
(
value
):
return
value
.
encode
()
if
isinstance
(
value
,
str
)
else
value
return
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
norm2bytes
(
values
)]))
def
image_seg_to_tfexample
(
image_data
,
filename
,
height
,
width
,
seg_data
):
def
image_seg_to_tfexample
(
image_data
,
filename
,
height
,
width
,
seg_data
):
...
...
research/deeplab/datasets/build_voc2012_data.py
View file @
26c96542
...
@@ -114,13 +114,13 @@ def _convert_dataset(dataset_split):
...
@@ -114,13 +114,13 @@ def _convert_dataset(dataset_split):
# Read the image.
# Read the image.
image_filename
=
os
.
path
.
join
(
image_filename
=
os
.
path
.
join
(
FLAGS
.
image_folder
,
filenames
[
i
]
+
'.'
+
FLAGS
.
image_format
)
FLAGS
.
image_folder
,
filenames
[
i
]
+
'.'
+
FLAGS
.
image_format
)
image_data
=
tf
.
gfile
.
FastGFile
(
image_filename
,
'r'
).
read
()
image_data
=
tf
.
gfile
.
FastGFile
(
image_filename
,
'r
b
'
).
read
()
height
,
width
=
image_reader
.
read_image_dims
(
image_data
)
height
,
width
=
image_reader
.
read_image_dims
(
image_data
)
# Read the semantic segmentation annotation.
# Read the semantic segmentation annotation.
seg_filename
=
os
.
path
.
join
(
seg_filename
=
os
.
path
.
join
(
FLAGS
.
semantic_segmentation_folder
,
FLAGS
.
semantic_segmentation_folder
,
filenames
[
i
]
+
'.'
+
FLAGS
.
label_format
)
filenames
[
i
]
+
'.'
+
FLAGS
.
label_format
)
seg_data
=
tf
.
gfile
.
FastGFile
(
seg_filename
,
'r'
).
read
()
seg_data
=
tf
.
gfile
.
FastGFile
(
seg_filename
,
'r
b
'
).
read
()
seg_height
,
seg_width
=
label_reader
.
read_image_dims
(
seg_data
)
seg_height
,
seg_width
=
label_reader
.
read_image_dims
(
seg_data
)
if
height
!=
seg_height
or
width
!=
seg_width
:
if
height
!=
seg_height
or
width
!=
seg_width
:
raise
RuntimeError
(
'Shape mismatched between image and label.'
)
raise
RuntimeError
(
'Shape mismatched between image and label.'
)
...
...
research/deeplab/eval.py
View file @
26c96542
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
See model.py for more details and usage.
See model.py for more details and usage.
"""
"""
import
six
import
math
import
math
import
tensorflow
as
tf
import
tensorflow
as
tf
from
deeplab
import
common
from
deeplab
import
common
...
@@ -144,7 +145,7 @@ def main(unused_argv):
...
@@ -144,7 +145,7 @@ def main(unused_argv):
metrics_to_values
,
metrics_to_updates
=
(
metrics_to_values
,
metrics_to_updates
=
(
tf
.
contrib
.
metrics
.
aggregate_metric_map
(
metric_map
))
tf
.
contrib
.
metrics
.
aggregate_metric_map
(
metric_map
))
for
metric_name
,
metric_value
in
metrics_to_values
.
iteritems
(
):
for
metric_name
,
metric_value
in
six
.
iteritems
(
metrics_to_values
):
slim
.
summaries
.
add_scalar_summary
(
slim
.
summaries
.
add_scalar_summary
(
metric_value
,
metric_name
,
print_summary
=
True
)
metric_value
,
metric_name
,
print_summary
=
True
)
...
@@ -163,7 +164,7 @@ def main(unused_argv):
...
@@ -163,7 +164,7 @@ def main(unused_argv):
checkpoint_dir
=
FLAGS
.
checkpoint_dir
,
checkpoint_dir
=
FLAGS
.
checkpoint_dir
,
logdir
=
FLAGS
.
eval_logdir
,
logdir
=
FLAGS
.
eval_logdir
,
num_evals
=
num_batches
,
num_evals
=
num_batches
,
eval_op
=
metrics_to_updates
.
values
(),
eval_op
=
list
(
metrics_to_updates
.
values
()
)
,
max_number_of_evaluations
=
num_eval_iters
,
max_number_of_evaluations
=
num_eval_iters
,
eval_interval_secs
=
FLAGS
.
eval_interval_secs
)
eval_interval_secs
=
FLAGS
.
eval_interval_secs
)
...
...
research/deeplab/train.py
View file @
26c96542
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
See model.py for more details and usage.
See model.py for more details and usage.
"""
"""
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
from
deeplab
import
common
from
deeplab
import
common
from
deeplab
import
model
from
deeplab
import
model
...
@@ -190,7 +191,7 @@ def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label):
...
@@ -190,7 +191,7 @@ def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label):
is_training
=
True
,
is_training
=
True
,
fine_tune_batch_norm
=
FLAGS
.
fine_tune_batch_norm
)
fine_tune_batch_norm
=
FLAGS
.
fine_tune_batch_norm
)
for
output
,
num_classes
in
outputs_to_num_classes
.
iteritems
(
):
for
output
,
num_classes
in
six
.
iteritems
(
outputs_to_num_classes
):
train_utils
.
add_softmax_cross_entropy_loss_for_each_scale
(
train_utils
.
add_softmax_cross_entropy_loss_for_each_scale
(
outputs_to_scales_to_logits
[
output
],
outputs_to_scales_to_logits
[
output
],
samples
[
common
.
LABEL
],
samples
[
common
.
LABEL
],
...
@@ -217,7 +218,7 @@ def main(unused_argv):
...
@@ -217,7 +218,7 @@ def main(unused_argv):
assert
FLAGS
.
train_batch_size
%
config
.
num_clones
==
0
,
(
assert
FLAGS
.
train_batch_size
%
config
.
num_clones
==
0
,
(
'Training batch size not divisble by number of clones (GPUs).'
)
'Training batch size not divisble by number of clones (GPUs).'
)
clone_batch_size
=
FLAGS
.
train_batch_size
/
config
.
num_clones
clone_batch_size
=
int
(
FLAGS
.
train_batch_size
/
config
.
num_clones
)
# Get dataset-dependent information.
# Get dataset-dependent information.
dataset
=
segmentation_dataset
.
get_dataset
(
dataset
=
segmentation_dataset
.
get_dataset
(
...
...
research/deeplab/utils/train_utils.py
View file @
26c96542
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
# ==============================================================================
# ==============================================================================
"""Utility functions for training."""
"""Utility functions for training."""
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
slim
=
tf
.
contrib
.
slim
slim
=
tf
.
contrib
.
slim
...
@@ -44,7 +46,7 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
...
@@ -44,7 +46,7 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
if
labels
is
None
:
if
labels
is
None
:
raise
ValueError
(
'No label for softmax cross entropy loss.'
)
raise
ValueError
(
'No label for softmax cross entropy loss.'
)
for
scale
,
logits
in
scales_to_logits
.
iteritems
(
):
for
scale
,
logits
in
six
.
iteritems
(
scales_to_logits
):
loss_scope
=
None
loss_scope
=
None
if
scope
:
if
scope
:
loss_scope
=
'%s_%s'
%
(
scope
,
scale
)
loss_scope
=
'%s_%s'
%
(
scope
,
scale
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment