ModelZoo / ResNet50_tensorflow / Commits

Commit be9b8025 (unverified), authored Feb 21, 2018 by Jonathan Huang; committed Feb 21, 2018 by GitHub.

Merge pull request #3380 from pkulzc/master

Internal changes for object detection.

Parents: d3143cbc, c173234f
Changes: 26 in total; this page shows 20 changed files with 122 additions and 66 deletions (+122, -66).
research/object_detection/builders/anchor_generator_builder_test.py  (+0, -1)
research/object_detection/builders/dataset_builder.py  (+5, -7)
research/object_detection/builders/model_builder.py  (+2, -0)
research/object_detection/builders/model_builder_test.py  (+1, -0)
research/object_detection/data_decoders/tf_example_decoder.py  (+17, -7)
research/object_detection/g3doc/installation.md  (+20, -0)
research/object_detection/meta_architectures/ssd_meta_arch.py  (+10, -1)
research/object_detection/meta_architectures/ssd_meta_arch_test.py  (+5, -5)
research/object_detection/protos/input_reader.proto  (+7, -4)
research/object_detection/protos/ssd.proto  (+4, -0)
research/object_detection/samples/configs/faster_rcnn_resnet101_kitti.config  (+0, -1)
research/object_detection/samples/configs/mask_rcnn_inception_resnet_v2_atrous_coco.config  (+4, -0)
research/object_detection/samples/configs/mask_rcnn_inception_v2_coco.config  (+4, -0)
research/object_detection/samples/configs/mask_rcnn_resnet101_atrous_coco.config  (+4, -0)
research/object_detection/samples/configs/mask_rcnn_resnet101_pets.config  (+2, -0)
research/object_detection/samples/configs/mask_rcnn_resnet50_atrous_coco.config  (+4, -0)
research/object_detection/train.py  (+1, -3)
research/object_detection/utils/BUILD  (+1, -0)
research/object_detection/utils/dataset_util.py  (+12, -24)
research/object_detection/utils/learning_schedules.py  (+19, -13)
research/object_detection/builders/anchor_generator_builder_test.py

@@ -266,7 +266,6 @@ class AnchorGeneratorBuilderTest(tf.test.TestCase):
     self.assertTrue(isinstance(anchor_generator_object,
                                multiscale_grid_anchor_generator.
                                MultiscaleGridAnchorGenerator))
-    print anchor_generator_object._anchor_grid_info
     for level, anchor_grid_info in zip(
         range(3, 8), anchor_generator_object._anchor_grid_info):
       self.assertEqual(set(anchor_grid_info.keys()), set(['level', 'info']))
research/object_detection/builders/dataset_builder.py

@@ -21,7 +21,7 @@ Note: If users wishes to also use their own InputReaders with the Object
 Detection configuration framework, they should define their own builder function
 that wraps the build function.
 """
+import functools
 import tensorflow as tf

 from object_detection.core import standard_fields as fields
@@ -86,8 +86,8 @@ def _get_padding_shapes(dataset, max_num_boxes, num_classes,
           for tensor_key, _ in dataset.output_shapes.items()}


-def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
-          worker_index=0, batch_size=1, max_num_boxes=None, num_classes=None,
+def build(input_reader_config, transform_input_data_fn=None, batch_size=1,
+          max_num_boxes=None, num_classes=None,
           spatial_image_shape=None):
   """Builds a tf.data.Dataset.
@@ -100,8 +100,6 @@ def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
     input_reader_config: A input_reader_pb2.InputReader object.
     transform_input_data_fn: Function to apply to all records, or None if
       no extra decoding is required.
-    num_workers: Number of workers (tpu shard).
-    worker_index: Id for the current worker (tpu shard).
     batch_size: Batch size. If not None, returns a padded batch dataset.
     max_num_boxes: Max number of groundtruth boxes needed to computes shapes for
       padding. This is only used if batch_size is greater than 1.
@@ -146,8 +144,8 @@ def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
       return processed

     dataset = dataset_util.read_dataset(
-        tf.data.TFRecordDataset, process_fn, config.input_path[:],
-        input_reader_config, num_workers, worker_index)
+        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
+        process_fn, config.input_path[:], input_reader_config)
     if batch_size > 1:
       if num_classes is None:
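The updated `build` no longer threads `num_workers`/`worker_index` through to `read_dataset`; instead it hands over a file-reading callable created with `functools.partial`, so the read buffer size is baked in. A minimal standalone sketch of that pattern (not the repository's code; the record path is a placeholder):

```python
import functools
import tensorflow as tf

# functools.partial pre-binds keyword arguments, so read_dataset can invoke the
# result with just a filename, exactly as it would call tf.data.TFRecordDataset.
file_read_func = functools.partial(
    tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000)

# Calling the partial with a single filename yields a dataset of serialized
# tf.Example records read through an 8 MB buffer.
records = file_read_func("/tmp/example.record")  # placeholder path
```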
research/object_detection/builders/model_builder.py

@@ -152,6 +152,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
   matcher = matcher_builder.build(ssd_config.matcher)
   region_similarity_calculator = sim_calc.build(
       ssd_config.similarity_calculator)
+  encode_background_as_zeros = ssd_config.encode_background_as_zeros
   ssd_box_predictor = box_predictor_builder.build(hyperparams_builder.build,
                                                   ssd_config.box_predictor,
                                                   is_training, num_classes)
@@ -173,6 +174,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
       feature_extractor,
       matcher,
       region_similarity_calculator,
+      encode_background_as_zeros,
       image_resizer_fn,
       non_max_suppression_fn,
       score_conversion_fn,
research/object_detection/builders/model_builder_test.py

@@ -237,6 +237,7 @@ class ModelBuilderTest(tf.test.TestCase):
         iou_similarity {
         }
       }
+      encode_background_as_zeros: true
       anchor_generator {
         multiscale_anchor_generator {
           aspect_ratios: [1.0, 2.0, 0.5]
research/object_detection/data_decoders/tf_example_decoder.py

@@ -35,7 +35,8 @@ class TfExampleDecoder(data_decoder.DataDecoder):
                load_instance_masks=False,
                instance_mask_type=input_reader_pb2.NUMERICAL_MASKS,
                label_map_proto_file=None,
-               use_display_name=False):
+               use_display_name=False,
+               dct_method=''):
     """Constructor sets keys_to_features and items_to_handlers.

     Args:
@@ -50,6 +51,11 @@ class TfExampleDecoder(data_decoder.DataDecoder):
       use_display_name: whether or not to use the `display_name` for label
         mapping (instead of `name`). Only used if label_map_proto_file is
         provided.
+      dct_method: An optional string. Defaults to None. It only takes
+        effect when image format is jpeg, used to specify a hint about the
+        algorithm used for jpeg decompression. Currently valid values
+        are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for
+        example, the jpeg library does not have that specific option.

     Raises:
       ValueError: If `instance_mask_type` option is not one of
@@ -96,8 +102,12 @@ class TfExampleDecoder(data_decoder.DataDecoder):
             tf.VarLenFeature(tf.float32),
     }
     self.items_to_handlers = {
-        fields.InputDataFields.image: slim_example_decoder.Image(
-            image_key='image/encoded', format_key='image/format', channels=3),
+        fields.InputDataFields.image:
+            slim_example_decoder.Image(
+                image_key='image/encoded',
+                format_key='image/format',
+                channels=3,
+                dct_method=dct_method),
         fields.InputDataFields.source_id: (
             slim_example_decoder.Tensor('image/source_id')),
         fields.InputDataFields.key: (
@@ -106,10 +116,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
             slim_example_decoder.Tensor('image/filename')),
         # Object boxes and classes.
         fields.InputDataFields.groundtruth_boxes: (
-            slim_example_decoder.BoundingBox(
-                ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')),
-        fields.InputDataFields.groundtruth_area: slim_example_decoder.Tensor(
-            'image/object/area'),
+            slim_example_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'],
+                                             'image/object/bbox/')),
+        fields.InputDataFields.groundtruth_area:
+            slim_example_decoder.Tensor('image/object/area'),
         fields.InputDataFields.groundtruth_is_crowd: (
             slim_example_decoder.Tensor('image/object/is_crowd')),
         fields.InputDataFields.groundtruth_difficult: (
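The new `dct_method` argument is forwarded to the slim `Image` handler and, from there, to TensorFlow's JPEG decoder. A minimal usage sketch (the class and argument come straight from this diff; the chosen value and the commented `decode` call are only illustrative):

```python
from object_detection.data_decoders import tf_example_decoder

# Ask the JPEG decoder for the slower but more accurate IDCT; the hint may be
# ignored if the linked libjpeg does not support it.
decoder = tf_example_decoder.TfExampleDecoder(dct_method='INTEGER_ACCURATE')

# tensor_dict would map fields.InputDataFields keys (image, groundtruth_boxes,
# ...) to tensors decoded from a serialized tf.Example proto:
# tensor_dict = decoder.decode(serialized_example)
```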
research/object_detection/g3doc/installation.md

@@ -12,6 +12,7 @@ Tensorflow Object Detection API depends on the following libraries:
 *   Jupyter notebook
 *   Matplotlib
 *   Tensorflow
+*   cocoapi

 For detailed steps to install Tensorflow, follow the [Tensorflow installation
 instructions](https://www.tensorflow.org/install/). A typical user can install
@@ -41,6 +42,25 @@ sudo pip install jupyter
 sudo pip install matplotlib
 ```

+## COCO API installation
+
+Download the <a href="https://github.com/cocodataset/cocoapi" target=_blank>cocoapi</a> and
+copy the pycocotools subfolder to the tensorflow/models/research directory if
+you are interested in using COCO evaluation metrics. The default metrics are
+based on those used in Pascal VOC evaluation. To use the COCO object detection
+metrics add `metrics_set: "coco_detection_metrics"` to the `eval_config` message
+in the config file. To use the COCO instance segmentation metrics add
+`metrics_set: "coco_mask_metrics"` to the `eval_config` message in the config
+file.
+
+```bash
+git clone https://github.com/cocodataset/cocoapi.git
+cd cocoapi/PythonAPI
+make
+cp -r pycocotools <path_to_tensorflow>/models/research/
+```
+
 ## Protobuf Compilation

 The Tensorflow Object Detection API uses Protobufs to configure model and
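A quick, optional sanity check that is not part of the documented steps: after copying pycocotools into models/research, the COCO evaluation modules should import cleanly from that directory.

```python
# Run from tensorflow/models/research after the `cp -r pycocotools ...` step.
# If these imports succeed, the COCO metrics can be enabled by adding
# metrics_set: "coco_detection_metrics" (or "coco_mask_metrics") to eval_config.
from pycocotools import coco       # COCO annotation loading
from pycocotools import cocoeval   # COCO detection/segmentation evaluation

print('pycocotools imported from:', coco.__file__)
```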
research/object_detection/meta_architectures/ssd_meta_arch.py

@@ -121,6 +121,7 @@ class SSDMetaArch(model.DetectionModel):
                feature_extractor,
                matcher,
                region_similarity_calculator,
+               encode_background_as_zeros,
                image_resizer_fn,
                non_max_suppression_fn,
                score_conversion_fn,
@@ -147,6 +148,9 @@ class SSDMetaArch(model.DetectionModel):
       matcher: a matcher.Matcher object.
       region_similarity_calculator: a
         region_similarity_calculator.RegionSimilarityCalculator object.
+      encode_background_as_zeros: boolean determining whether background
+        targets are to be encoded as an all zeros vector or a one-hot
+        vector (where background is the 0th class).
       image_resizer_fn: a callable for image resizing. This callable always
         takes a rank-3 image tensor (corresponding to a single image) and
         returns a rank-3 image tensor, possibly with new spatial dimensions and
@@ -190,7 +194,12 @@ class SSDMetaArch(model.DetectionModel):
     # TODO: handle agnostic mode and positive/negative class
     # weights
-    unmatched_cls_target = None
-    unmatched_cls_target = tf.constant([1] + self.num_classes * [0], tf.float32)
+    unmatched_cls_target = tf.constant([1] + self.num_classes * [0],
+                                       tf.float32)
+    if encode_background_as_zeros:
+      unmatched_cls_target = tf.constant((self.num_classes + 1) * [0],
+                                         tf.float32)
     self._target_assigner = target_assigner.TargetAssigner(
         self._region_similarity_calculator,
         self._matcher,
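The effect of the new flag on the unmatched (background) classification target, shown as a small standalone sketch; `num_classes = 3` is an arbitrary value chosen here for illustration:

```python
import tensorflow as tf

num_classes = 3  # arbitrary value for illustration

# Default behaviour: background is an explicit 0th class, so the unmatched
# target is a one-hot vector of length num_classes + 1 with the background
# slot set.
one_hot_background = tf.constant([1] + num_classes * [0], tf.float32)    # [1, 0, 0, 0]

# With encode_background_as_zeros=True the unmatched target is all zeros, so
# background is represented implicitly by "no class is hot".
all_zeros_background = tf.constant((num_classes + 1) * [0], tf.float32)  # [0, 0, 0, 0]
```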
research/object_detection/meta_architectures/ssd_meta_arch_test.py

@@ -84,7 +84,7 @@ class SsdMetaArchTest(test_case.TestCase):
     fake_feature_extractor = FakeSSDFeatureExtractor()
     mock_matcher = test_utils.MockMatcher()
     region_similarity_calculator = sim_calc.IouSimilarity()
+    encode_background_as_zeros = False

     def image_resizer_fn(image):
       return [tf.identity(image), tf.shape(image)]
@@ -111,10 +111,10 @@ class SsdMetaArchTest(test_case.TestCase):
     model = ssd_meta_arch.SSDMetaArch(
         is_training, mock_anchor_generator, mock_box_predictor, mock_box_coder,
         fake_feature_extractor, mock_matcher, region_similarity_calculator,
-        image_resizer_fn, non_max_suppression_fn, tf.identity,
-        classification_loss, localization_loss, classification_loss_weight,
-        localization_loss_weight, normalize_loss_by_num_matches,
-        hard_example_miner, add_summaries=False)
+        encode_background_as_zeros, image_resizer_fn, non_max_suppression_fn,
+        tf.identity, classification_loss, localization_loss,
+        classification_loss_weight, localization_loss_weight,
+        normalize_loss_by_num_matches, hard_example_miner, add_summaries=False)

     return model, num_classes, mock_anchor_generator.num_anchors(), code_size

   def test_preprocess_preserves_shapes_with_dynamic_input_image(self):
research/object_detection/protos/input_reader.proto

@@ -32,7 +32,7 @@ message InputReader {
   optional bool shuffle = 2 [default=true];

   // Buffer size to be used when shuffling.
-  optional uint32 shuffle_buffer_size = 11 [default = 100];
+  optional uint32 shuffle_buffer_size = 11 [default = 2048];

   // Buffer size to be used when shuffling file names.
   optional uint32 filenames_shuffle_buffer_size = 12 [default = 100];
@@ -49,10 +49,13 @@ message InputReader {
   optional uint32 num_epochs = 5 [default=0];

   // Number of reader instances to create.
-  optional uint32 num_readers = 6 [default = 8];
+  optional uint32 num_readers = 6 [default = 32];

-  // Size of the buffer for prefetching (in batches).
-  optional uint32 prefetch_buffer_size = 13 [default = 2];
+  // Number of decoded records to prefetch before batching.
+  optional uint32 prefetch_size = 13 [default = 512];
+
+  // Number of parallel decode ops to apply.
+  optional uint32 num_parallel_map_calls = 14 [default = 64];

   // Whether to load groundtruth instance masks.
   optional bool load_instance_masks = 7 [default = false];
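The renamed and added fields can be exercised by parsing a text-format `InputReader`; a small sketch (the field values here are illustrative, and the `tf_record_input_reader` path is the usual placeholder, not something this commit prescribes):

```python
from google.protobuf import text_format
from object_detection.protos import input_reader_pb2

config = text_format.Parse("""
  tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/train.record" }
  num_parallel_map_calls: 64
  prefetch_size: 512
""", input_reader_pb2.InputReader())

# Fields left unset fall back to the defaults raised in this commit.
print(config.num_readers)          # 32
print(config.shuffle_buffer_size)  # 2048
```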
research/object_detection/protos/ssd.proto

@@ -32,6 +32,10 @@ message Ssd {
   // Region similarity calculator to compute similarity of boxes.
   optional RegionSimilarityCalculator similarity_calculator = 6;

+  // Whether background targets are to be encoded as an all
+  // zeros vector or a one-hot vector (where background is the 0th class).
+  optional bool encode_background_as_zeros = 12 [default = false];
+
   // Box predictor to attach to the features.
   optional BoxPredictor box_predictor = 7;
research/object_detection/samples/configs/faster_rcnn_resnet101_kitti.config

@@ -129,7 +129,6 @@ train_input_reader: {
 }

 eval_config: {
-  metrics_set: "coco_metrics"
   use_moving_averages: false
   num_examples: 500
 }
research/object_detection/samples/configs/mask_rcnn_inception_resnet_v2_atrous_coco.config

@@ -147,6 +147,8 @@ train_input_reader: {
     input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
   }
   label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
+  load_instance_masks: true
+  mask_type: PNG_MASKS
 }

 eval_config: {
@@ -161,6 +163,8 @@ eval_input_reader: {
     input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record"
   }
   label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
+  load_instance_masks: true
+  mask_type: PNG_MASKS
   shuffle: false
   num_readers: 1
 }
research/object_detection/samples/configs/mask_rcnn_inception_v2_coco.config

@@ -146,6 +146,8 @@ train_input_reader: {
     input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
   }
   label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
+  load_instance_masks: true
+  mask_type: PNG_MASKS
 }

 eval_config: {
@@ -160,6 +162,8 @@ eval_input_reader: {
     input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record"
   }
   label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
+  load_instance_masks: true
+  mask_type: PNG_MASKS
   shuffle: false
   num_readers: 1
 }
research/object_detection/samples/configs/mask_rcnn_resnet101_atrous_coco.config

@@ -147,6 +147,8 @@ train_input_reader: {
     input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
   }
   label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
+  load_instance_masks: true
+  mask_type: PNG_MASKS
 }

 eval_config: {
@@ -161,6 +163,8 @@ eval_input_reader: {
     input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record"
   }
   label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
+  load_instance_masks: true
+  mask_type: PNG_MASKS
   shuffle: false
   num_readers: 1
 }
research/object_detection/samples/configs/mask_rcnn_resnet101_pets.config

@@ -140,6 +140,7 @@ train_input_reader: {
     input_path: "PATH_TO_BE_CONFIGURED/pet_train.record"
   }
   label_map_path: "PATH_TO_BE_CONFIGURED/pet_label_map.pbtxt"
+  load_instance_masks: true
 }

 eval_config: {
@@ -154,6 +155,7 @@ eval_input_reader: {
     input_path: "PATH_TO_BE_CONFIGURED/pet_val.record"
   }
   label_map_path: "PATH_TO_BE_CONFIGURED/pet_label_map.pbtxt"
+  load_instance_masks: true
   shuffle: false
   num_readers: 1
 }
research/object_detection/samples/configs/mask_rcnn_resnet50_atrous_coco.config

@@ -147,6 +147,8 @@ train_input_reader: {
     input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
   }
   label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
+  load_instance_masks: true
+  mask_type: PNG_MASKS
 }

 eval_config: {
@@ -161,6 +163,8 @@ eval_input_reader: {
     input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record"
   }
   label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
+  load_instance_masks: true
+  mask_type: PNG_MASKS
   shuffle: false
   num_readers: 1
 }
research/object_detection/train.py

@@ -117,9 +117,7 @@ def main(_):
   def get_next(config):
     return dataset_util.make_initializable_iterator(
-        dataset_builder.build(
-            config,
-            num_workers=FLAGS.worker_replicas,
-            worker_index=FLAGS.task)).get_next()
+        dataset_builder.build(config)).get_next()

   create_input_dict_fn = functools.partial(get_next, input_config)
research/object_detection/utils/BUILD

@@ -264,6 +264,7 @@ py_test(
     srcs = ["learning_schedules_test.py"],
     deps = [
         ":learning_schedules",
+        ":test_case",
         "//tensorflow",
     ],
 )
research/object_detection/utils/dataset_util.py

@@ -103,9 +103,7 @@ def make_initializable_iterator(dataset):
   return iterator


-def read_dataset(
-    file_read_func, decode_func, input_files, config, num_workers=1,
-    worker_index=0):
+def read_dataset(file_read_func, decode_func, input_files, config):
   """Reads a dataset, and handles repetition and shuffling.

   Args:
@@ -114,8 +112,6 @@ def read_dataset(
     decode_func: Function to apply to all records.
     input_files: A list of file paths to read.
     config: A input_reader_builder.InputReader object.
-    num_workers: Number of workers / shards.
-    worker_index: Id for the current worker.

   Returns:
     A tf.data.Dataset based on config.
@@ -123,25 +119,17 @@ def read_dataset(
   # Shard, shuffle, and read files.
   filenames = tf.concat([tf.matching_files(pattern) for pattern in input_files],
                         0)
-  dataset = tf.data.Dataset.from_tensor_slices(filenames)
-  dataset = dataset.shard(num_workers, worker_index)
-  dataset = dataset.repeat(config.num_epochs or None)
-  if config.shuffle:
-    dataset = dataset.shuffle(config.filenames_shuffle_buffer_size,
-                              reshuffle_each_iteration=True)
-
-  # Read file records and shuffle them.
-  # If cycle_length is larger than the number of files, more than one reader
-  # will be assigned to the same file, leading to repetition.
-  cycle_length = tf.cast(
-      tf.minimum(config.num_readers, tf.size(filenames)), tf.int64)
-  # TODO: find the optimal block_length.
-  dataset = dataset.interleave(
-      file_read_func, cycle_length=cycle_length, block_length=1)
-
+  filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
   if config.shuffle:
-    dataset = dataset.shuffle(config.shuffle_buffer_size,
-                              reshuffle_each_iteration=True)
-
-  dataset = dataset.map(decode_func, num_parallel_calls=config.num_readers)
-  return dataset.prefetch(config.prefetch_buffer_size)
+    filename_dataset = filename_dataset.shuffle(
+        config.filenames_shuffle_buffer_size)
+  filename_dataset = filename_dataset.repeat(config.num_epochs or None)
+  records_dataset = filename_dataset.apply(
+      tf.contrib.data.parallel_interleave(
+          file_read_func, cycle_length=config.num_readers, sloppy=True))
+  if config.shuffle:
+    records_dataset.shuffle(config.shuffle_buffer_size)
+  tensor_dataset = records_dataset.map(
+      decode_func, num_parallel_calls=config.num_parallel_map_calls)
+  return tensor_dataset.prefetch(config.prefetch_size)
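The new `read_dataset` replaces manual sharding plus `Dataset.interleave` with `tf.contrib.data.parallel_interleave`, which reads from several record files concurrently (and, with `sloppy=True`, in non-deterministic order) before decoding and prefetching. A stripped-down TF 1.x sketch of the same pipeline shape, with made-up parameter defaults and placeholder file patterns:

```python
import tensorflow as tf

def tiny_read_dataset(file_read_func, decode_func, filename_patterns,
                      num_readers=32, shuffle_buffer=2048, prefetch=512):
  """Minimal sketch mirroring the structure of the new read_dataset."""
  filenames = tf.concat(
      [tf.matching_files(p) for p in filename_patterns], 0)
  filename_dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat()
  # Pull records from num_readers files at once; sloppy=True trades determinism
  # for throughput by emitting elements as soon as any reader has one ready.
  records = filename_dataset.apply(
      tf.contrib.data.parallel_interleave(
          file_read_func, cycle_length=num_readers, sloppy=True))
  records = records.shuffle(shuffle_buffer)
  return records.map(decode_func,
                     num_parallel_calls=num_readers).prefetch(prefetch)

# Example wiring (paths are placeholders):
# dataset = tiny_read_dataset(tf.data.TFRecordDataset, lambda x: x,
#                             ["/tmp/train-*.record"])
```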
research/object_detection/utils/learning_schedules.py

@@ -53,10 +53,10 @@ def exponential_decay_with_burnin(global_step,
       learning_rate_decay_steps,
       learning_rate_decay_factor,
       staircase=True)
-  return tf.cond(
-      tf.less(global_step, burnin_steps),
-      lambda: tf.convert_to_tensor(burnin_learning_rate),
-      lambda: post_burnin_learning_rate)
+  return tf.where(
+      tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
+      tf.constant(burnin_learning_rate),
+      post_burnin_learning_rate)


 def cosine_decay_with_warmup(global_step,
@@ -100,9 +100,10 @@ def cosine_decay_with_warmup(global_step,
     slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
     pre_cosine_learning_rate = slope * tf.cast(
         global_step, tf.float32) + warmup_learning_rate
-    learning_rate = tf.cond(
-        tf.less(global_step, warmup_steps),
-        lambda: pre_cosine_learning_rate,
-        lambda: learning_rate)
+    learning_rate = tf.where(
+        tf.less(tf.cast(global_step, tf.int32), warmup_steps),
+        pre_cosine_learning_rate,
+        learning_rate)
   return learning_rate
@@ -141,10 +142,15 @@ def manual_stepping(global_step, boundaries, rates):
   if len(rates) != len(boundaries) + 1:
     raise ValueError('Number of provided learning rates must exceed '
                      'number of boundary points by exactly 1.')
-  step_boundaries = tf.constant(boundaries, tf.int64)
+  step_boundaries = tf.constant(boundaries, tf.int32)
+  num_boundaries = len(boundaries)
   learning_rates = tf.constant(rates, tf.float32)
-  unreached_boundaries = tf.reshape(
-      tf.where(tf.greater(step_boundaries, global_step)), [-1])
-  unreached_boundaries = tf.concat([unreached_boundaries, [len(boundaries)]], 0)
-  index = tf.reshape(tf.reduce_min(unreached_boundaries), [1])
-  return tf.reshape(tf.slice(learning_rates, index, [1]), [])
+  index = tf.reduce_min(
+      tf.where(
+          # Casting global step to tf.int32 is dangerous, but necessary to be
+          # compatible with TPU.
+          tf.greater(step_boundaries, tf.cast(global_step, tf.int32)),
+          tf.constant(range(num_boundaries), dtype=tf.int32),
+          tf.constant([num_boundaries] * num_boundaries, dtype=tf.int32)))
+  return tf.reduce_sum(learning_rates * tf.one_hot(index, len(rates),
+                                                   dtype=tf.float32))
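These schedules switch from `tf.cond` to element-wise `tf.where` plus `tf.one_hot` selection, which avoids data-dependent control flow (both branches are always computed), the property that makes the graphs TPU-friendly. A standalone sketch of the `manual_stepping`-style selection; the boundaries and rates below are made-up values for illustration:

```python
import tensorflow as tf

boundaries = [1000, 5000]    # illustrative step boundaries
rates = [0.1, 0.01, 0.001]   # one more rate than there are boundaries
global_step = tf.train.get_or_create_global_step()

step_boundaries = tf.constant(boundaries, tf.int32)
num_boundaries = len(boundaries)
learning_rates = tf.constant(rates, tf.float32)

# For each boundary still ahead of global_step keep its index, otherwise use
# num_boundaries; the minimum is the index of the current learning-rate phase.
index = tf.reduce_min(
    tf.where(tf.greater(step_boundaries, tf.cast(global_step, tf.int32)),
             tf.constant(list(range(num_boundaries)), dtype=tf.int32),
             tf.constant([num_boundaries] * num_boundaries, dtype=tf.int32)))

# Selecting via one_hot + reduce_sum keeps everything element-wise instead of
# branching, mirroring the tf.cond -> tf.where change in this commit.
learning_rate = tf.reduce_sum(
    learning_rates * tf.one_hot(index, len(rates), dtype=tf.float32))
```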