Commit 324d6dc3 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Merged commit includes the following changes:

196161788  by Zhichao Lu:

    Add eval_on_train_steps parameter.

    The number of samples in the train dataset is usually different from the number of samples in the eval dataset.

--
196151742  by Zhichao Lu:

    Add an optional random sampling process for SSD meta arch and update mean stddev coder to use default std dev when corresponding tensor is not added to boxlist field.

--
196148940  by Zhichao Lu:

    Release ssdlite mobilenet v2 coco trained model.

--
196058528  by Zhichao Lu:

    Apply FPN feature map generation before we add additional layers on top of resnet feature extractor.

--
195818367  by Zhichao Lu:

    Add support for exporting detection keypoints.

--
195745420  by Zhichao Lu:

    Introduce include_metrics_per_category option to Object Detection eval_config.

--
195734733  by Zhichao Lu:

    Rename SSDLite config to be more explicit.

--
195717383  by Zhichao Lu:

    Add quantized training to object_detection.

--
195683542  by...
parent 63054210
......@@ -22,6 +22,7 @@ from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
from object_detection.protos import eval_pb2
from object_detection.protos import graph_rewriter_pb2
from object_detection.protos import input_reader_pb2
from object_detection.protos import model_pb2
from object_detection.protos import pipeline_pb2
......@@ -111,9 +112,27 @@ def create_configs_from_pipeline_proto(pipeline_config):
configs["train_input_config"] = pipeline_config.train_input_reader
configs["eval_config"] = pipeline_config.eval_config
configs["eval_input_config"] = pipeline_config.eval_input_reader
if pipeline_config.HasField("graph_rewriter"):
configs["graph_rewriter_config"] = pipeline_config.graph_rewriter
return configs
def get_graph_rewriter_config_from_file(graph_rewriter_config_file):
  """Reads a graph rewriter config from a text-format proto file.

  Args:
    graph_rewriter_config_file: file path to the graph rewriter config.

  Returns:
    graph_rewriter_pb2.GraphRewriter proto
  """
  rewriter_config = graph_rewriter_pb2.GraphRewriter()
  # Read the whole file first, then merge its text-format contents into the
  # freshly created proto.
  with tf.gfile.GFile(graph_rewriter_config_file, "r") as config_file:
    config_text = config_file.read()
  text_format.Merge(config_text, rewriter_config)
  return rewriter_config
def create_pipeline_proto_from_configs(configs):
"""Creates a pipeline_pb2.TrainEvalPipelineConfig from configs dictionary.
......@@ -132,6 +151,8 @@ def create_pipeline_proto_from_configs(configs):
pipeline_config.train_input_reader.CopyFrom(configs["train_input_config"])
pipeline_config.eval_config.CopyFrom(configs["eval_config"])
pipeline_config.eval_input_reader.CopyFrom(configs["eval_input_config"])
if "graph_rewriter_config" in configs:
pipeline_config.graph_rewriter.CopyFrom(configs["graph_rewriter_config"])
return pipeline_config
......@@ -157,7 +178,8 @@ def get_configs_from_multiple_files(model_config_path="",
train_config_path="",
train_input_config_path="",
eval_config_path="",
eval_input_config_path=""):
eval_input_config_path="",
graph_rewriter_config_path=""):
"""Reads training configuration from multiple config files.
Args:
......@@ -166,6 +188,7 @@ def get_configs_from_multiple_files(model_config_path="",
train_input_config_path: Path to input_reader_pb2.InputReader.
eval_config_path: Path to eval_pb2.EvalConfig.
eval_input_config_path: Path to input_reader_pb2.InputReader.
graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
......@@ -203,6 +226,10 @@ def get_configs_from_multiple_files(model_config_path="",
text_format.Merge(f.read(), eval_input_config)
configs["eval_input_config"] = eval_input_config
if graph_rewriter_config_path:
configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
graph_rewriter_config_path)
return configs
......
......@@ -132,7 +132,7 @@ def read_dataset(file_read_func, decode_func, input_files, config):
records_dataset = filename_dataset.apply(
tf.contrib.data.parallel_interleave(
file_read_func, cycle_length=config.num_readers,
block_length=config.read_block_length, sloppy=True))
block_length=config.read_block_length, sloppy=config.shuffle))
if config.shuffle:
records_dataset.shuffle(config.shuffle_buffer_size)
tensor_dataset = records_dataset.map(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment