Commit 3616ab28 authored by Vivek Rathod's avatar Vivek Rathod
Browse files

refactor config parsing in train.py binaries and use functions in utils/config_utils.py instead.

parent 9cfd2d93
...@@ -7,7 +7,6 @@ package( ...@@ -7,7 +7,6 @@ package(
licenses(["notice"]) licenses(["notice"])
# Apache 2.0 # Apache 2.0
py_binary( py_binary(
name = "train", name = "train",
srcs = [ srcs = [
...@@ -18,10 +17,7 @@ py_binary( ...@@ -18,10 +17,7 @@ py_binary(
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/builders:input_reader_builder", "//tensorflow_models/object_detection/builders:input_reader_builder",
"//tensorflow_models/object_detection/builders:model_builder", "//tensorflow_models/object_detection/builders:model_builder",
"//tensorflow_models/object_detection/protos:input_reader_py_pb2", "//tensorflow_models/object_detection/utils:config_util",
"//tensorflow_models/object_detection/protos:model_py_pb2",
"//tensorflow_models/object_detection/protos:pipeline_py_pb2",
"//tensorflow_models/object_detection/protos:train_py_pb2",
], ],
) )
...@@ -33,6 +29,7 @@ py_library( ...@@ -33,6 +29,7 @@ py_library(
"//tensorflow_models/object_detection/builders:optimizer_builder", "//tensorflow_models/object_detection/builders:optimizer_builder",
"//tensorflow_models/object_detection/builders:preprocessor_builder", "//tensorflow_models/object_detection/builders:preprocessor_builder",
"//tensorflow_models/object_detection/core:batcher", "//tensorflow_models/object_detection/core:batcher",
"//tensorflow_models/object_detection/core:preprocessor",
"//tensorflow_models/object_detection/core:standard_fields", "//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/utils:ops", "//tensorflow_models/object_detection/utils:ops",
"//tensorflow_models/object_detection/utils:variables_helper", "//tensorflow_models/object_detection/utils:variables_helper",
......
...@@ -46,15 +46,10 @@ import json ...@@ -46,15 +46,10 @@ import json
import os import os
import tensorflow as tf import tensorflow as tf
from google.protobuf import text_format
from object_detection import trainer from object_detection import trainer
from object_detection.builders import input_reader_builder from object_detection.builders import input_reader_builder
from object_detection.builders import model_builder from object_detection.builders import model_builder
from object_detection.protos import input_reader_pb2 from object_detection.utils import config_util
from object_detection.protos import model_pb2
from object_detection.protos import pipeline_pb2
from object_detection.protos import train_pb2
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
...@@ -88,61 +83,31 @@ flags.DEFINE_string('model_config_path', '', ...@@ -88,61 +83,31 @@ flags.DEFINE_string('model_config_path', '',
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def get_configs_from_pipeline_file():
"""Reads training configuration from a pipeline_pb2.TrainEvalPipelineConfig.
Reads training config from file specified by pipeline_config_path flag.
Returns:
model_config: model_pb2.DetectionModel
train_config: train_pb2.TrainConfig
input_config: input_reader_pb2.InputReader
"""
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
text_format.Merge(f.read(), pipeline_config)
model_config = pipeline_config.model
train_config = pipeline_config.train_config
input_config = pipeline_config.train_input_reader
return model_config, train_config, input_config
def get_configs_from_multiple_files():
"""Reads training configuration from multiple config files.
Reads the training config from the following files:
model_config: Read from --model_config_path
train_config: Read from --train_config_path
input_config: Read from --input_config_path
Returns:
model_config: model_pb2.DetectionModel
train_config: train_pb2.TrainConfig
input_config: input_reader_pb2.InputReader
"""
train_config = train_pb2.TrainConfig()
with tf.gfile.GFile(FLAGS.train_config_path, 'r') as f:
text_format.Merge(f.read(), train_config)
model_config = model_pb2.DetectionModel()
with tf.gfile.GFile(FLAGS.model_config_path, 'r') as f:
text_format.Merge(f.read(), model_config)
input_config = input_reader_pb2.InputReader()
with tf.gfile.GFile(FLAGS.input_config_path, 'r') as f:
text_format.Merge(f.read(), input_config)
return model_config, train_config, input_config
def main(_): def main(_):
assert FLAGS.train_dir, '`train_dir` is missing.' assert FLAGS.train_dir, '`train_dir` is missing.'
if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
if FLAGS.pipeline_config_path: if FLAGS.pipeline_config_path:
model_config, train_config, input_config = get_configs_from_pipeline_file() configs = config_util.get_configs_from_pipeline_file(
FLAGS.pipeline_config_path)
if FLAGS.task == 0:
tf.gfile.Copy(FLAGS.pipeline_config_path,
os.path.join(FLAGS.train_dir, 'pipeline.config'),
overwrite=True)
else: else:
model_config, train_config, input_config = get_configs_from_multiple_files() configs = config_util.get_configs_from_multiple_files(
model_config_path=FLAGS.model_config_path,
train_config_path=FLAGS.train_config_path,
train_input_config_path=FLAGS.input_config_path)
if FLAGS.task == 0:
for name, config in [('model.config', FLAGS.model_config_path),
('train.config', FLAGS.train_config_path),
('input.config', FLAGS.input_config_path)]:
tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
overwrite=True)
model_config = configs['model']
train_config = configs['train_config']
input_config = configs['train_input_config']
model_fn = functools.partial( model_fn = functools.partial(
model_builder.build, model_builder.build,
......
...@@ -35,9 +35,9 @@ from deployment import model_deploy ...@@ -35,9 +35,9 @@ from deployment import model_deploy
slim = tf.contrib.slim slim = tf.contrib.slim
def _create_input_queue(batch_size_per_clone, create_tensor_dict_fn, def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
batch_queue_capacity, num_batch_queue_threads, batch_queue_capacity, num_batch_queue_threads,
prefetch_queue_capacity, data_augmentation_options): prefetch_queue_capacity, data_augmentation_options):
"""Sets up reader, prefetcher and returns input queue. """Sets up reader, prefetcher and returns input queue.
Args: Args:
...@@ -65,9 +65,16 @@ def _create_input_queue(batch_size_per_clone, create_tensor_dict_fn, ...@@ -65,9 +65,16 @@ def _create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
float_images = tf.to_float(images) float_images = tf.to_float(images)
tensor_dict[fields.InputDataFields.image] = float_images tensor_dict[fields.InputDataFields.image] = float_images
include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
in tensor_dict)
include_keypoints = (fields.InputDataFields.groundtruth_keypoints
in tensor_dict)
if data_augmentation_options: if data_augmentation_options:
tensor_dict = preprocessor.preprocess(tensor_dict, tensor_dict = preprocessor.preprocess(
data_augmentation_options) tensor_dict, data_augmentation_options,
func_arg_map=preprocessor.get_default_func_arg_map(
include_instance_masks=include_instance_masks,
include_keypoints=include_keypoints))
input_queue = batcher.BatchQueue( input_queue = batcher.BatchQueue(
tensor_dict, tensor_dict,
...@@ -78,56 +85,83 @@ def _create_input_queue(batch_size_per_clone, create_tensor_dict_fn, ...@@ -78,56 +85,83 @@ def _create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
return input_queue return input_queue
def _get_inputs(input_queue, num_classes): def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
"""Dequeue batch and construct inputs to object detection model. """Dequeues batch and constructs inputs to object detection model.
Args: Args:
input_queue: BatchQueue object holding enqueued tensor_dicts. input_queue: BatchQueue object holding enqueued tensor_dicts.
num_classes: Number of classes. num_classes: Number of classes.
merge_multiple_label_boxes: Whether to merge boxes with multiple labels
or not. Defaults to false. Merged boxes are represented with a single
box and a k-hot encoding of the multiple labels associated with the
boxes.
Returns: Returns:
images: a list of 3-D float tensor of images. images: a list of 3-D float tensor of images.
image_keys: a list of string keys for the images.
locations_list: a list of tensors of shape [num_boxes, 4] locations_list: a list of tensors of shape [num_boxes, 4]
containing the corners of the groundtruth boxes. containing the corners of the groundtruth boxes.
classes_list: a list of padded one-hot tensors containing target classes. classes_list: a list of padded one-hot tensors containing target classes.
masks_list: a list of 3-D float tensors of shape [num_boxes, image_height, masks_list: a list of 3-D float tensors of shape [num_boxes, image_height,
image_width] containing instance masks for objects if present in the image_width] containing instance masks for objects if present in the
input_queue. Else returns None. input_queue. Else returns None.
keypoints_list: a list of 3-D float tensors of shape [num_boxes,
num_keypoints, 2] containing keypoints for objects if present in the
input queue. Else returns None.
""" """
read_data_list = input_queue.dequeue() read_data_list = input_queue.dequeue()
label_id_offset = 1 label_id_offset = 1
def extract_images_and_targets(read_data): def extract_images_and_targets(read_data):
"""Extract images and targets from the input dict."""
image = read_data[fields.InputDataFields.image] image = read_data[fields.InputDataFields.image]
key = ''
if fields.InputDataFields.source_id in read_data:
key = read_data[fields.InputDataFields.source_id]
location_gt = read_data[fields.InputDataFields.groundtruth_boxes] location_gt = read_data[fields.InputDataFields.groundtruth_boxes]
classes_gt = tf.cast(read_data[fields.InputDataFields.groundtruth_classes], classes_gt = tf.cast(read_data[fields.InputDataFields.groundtruth_classes],
tf.int32) tf.int32)
classes_gt -= label_id_offset classes_gt -= label_id_offset
classes_gt = util_ops.padded_one_hot_encoding(indices=classes_gt, if merge_multiple_label_boxes:
depth=num_classes, left_pad=0) location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
location_gt, classes_gt, num_classes)
else:
classes_gt = util_ops.padded_one_hot_encoding(
indices=classes_gt, depth=num_classes, left_pad=0)
masks_gt = read_data.get(fields.InputDataFields.groundtruth_instance_masks) masks_gt = read_data.get(fields.InputDataFields.groundtruth_instance_masks)
return image, location_gt, classes_gt, masks_gt keypoints_gt = read_data.get(fields.InputDataFields.groundtruth_keypoints)
if (merge_multiple_label_boxes and (
masks_gt is not None or keypoints_gt is not None)):
raise NotImplementedError('Multi-label support is only for boxes.')
return image, key, location_gt, classes_gt, masks_gt, keypoints_gt
return zip(*map(extract_images_and_targets, read_data_list)) return zip(*map(extract_images_and_targets, read_data_list))
def _create_losses(input_queue, create_model_fn): def _create_losses(input_queue, create_model_fn, train_config):
"""Creates loss function for a DetectionModel. """Creates loss function for a DetectionModel.
Args: Args:
input_queue: BatchQueue object holding enqueued tensor_dicts. input_queue: BatchQueue object holding enqueued tensor_dicts.
create_model_fn: A function to create the DetectionModel. create_model_fn: A function to create the DetectionModel.
train_config: a train_pb2.TrainConfig protobuf.
""" """
detection_model = create_model_fn() detection_model = create_model_fn()
(images, groundtruth_boxes_list, groundtruth_classes_list, (images, _, groundtruth_boxes_list, groundtruth_classes_list,
groundtruth_masks_list groundtruth_masks_list, groundtruth_keypoints_list) = get_inputs(
) = _get_inputs(input_queue, detection_model.num_classes) input_queue,
detection_model.num_classes,
train_config.merge_multiple_label_boxes)
images = [detection_model.preprocess(image) for image in images] images = [detection_model.preprocess(image) for image in images]
images = tf.concat(images, 0) images = tf.concat(images, 0)
if any(mask is None for mask in groundtruth_masks_list): if any(mask is None for mask in groundtruth_masks_list):
groundtruth_masks_list = None groundtruth_masks_list = None
if any(keypoints is None for keypoints in groundtruth_keypoints_list):
groundtruth_keypoints_list = None
detection_model.provide_groundtruth(groundtruth_boxes_list, detection_model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list, groundtruth_classes_list,
groundtruth_masks_list) groundtruth_masks_list,
groundtruth_keypoints_list)
prediction_dict = detection_model.predict(images) prediction_dict = detection_model.predict(images)
losses_dict = detection_model.loss(prediction_dict) losses_dict = detection_model.loss(prediction_dict)
...@@ -176,19 +210,21 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task, ...@@ -176,19 +210,21 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
global_step = slim.create_global_step() global_step = slim.create_global_step()
with tf.device(deploy_config.inputs_device()): with tf.device(deploy_config.inputs_device()):
input_queue = _create_input_queue(train_config.batch_size // num_clones, input_queue = create_input_queue(
create_tensor_dict_fn, train_config.batch_size // num_clones, create_tensor_dict_fn,
train_config.batch_queue_capacity, train_config.batch_queue_capacity,
train_config.num_batch_queue_threads, train_config.num_batch_queue_threads,
train_config.prefetch_queue_capacity, train_config.prefetch_queue_capacity, data_augmentation_options)
data_augmentation_options)
# Gather initial summaries. # Gather initial summaries.
# TODO(rathodv): See if summaries can be added/extracted from global tf
# collections so that they don't have to be passed around.
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
global_summaries = set([]) global_summaries = set([])
model_fn = functools.partial(_create_losses, model_fn = functools.partial(_create_losses,
create_model_fn=create_model_fn) create_model_fn=create_model_fn,
train_config=train_config)
clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue]) clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
first_clone_scope = clones[0].scope first_clone_scope = clones[0].scope
......
...@@ -32,6 +32,7 @@ NUMBER_OF_CLASSES = 2 ...@@ -32,6 +32,7 @@ NUMBER_OF_CLASSES = 2
def get_input_function(): def get_input_function():
"""A function to get test inputs. Returns an image with one box.""" """A function to get test inputs. Returns an image with one box."""
image = tf.random_uniform([32, 32, 3], dtype=tf.float32) image = tf.random_uniform([32, 32, 3], dtype=tf.float32)
key = tf.constant('image_000000')
class_label = tf.random_uniform( class_label = tf.random_uniform(
[1], minval=0, maxval=NUMBER_OF_CLASSES, dtype=tf.int32) [1], minval=0, maxval=NUMBER_OF_CLASSES, dtype=tf.int32)
box_label = tf.random_uniform( box_label = tf.random_uniform(
...@@ -39,6 +40,7 @@ def get_input_function(): ...@@ -39,6 +40,7 @@ def get_input_function():
return { return {
fields.InputDataFields.image: image, fields.InputDataFields.image: image,
fields.InputDataFields.key: key,
fields.InputDataFields.groundtruth_classes: class_label, fields.InputDataFields.groundtruth_classes: class_label,
fields.InputDataFields.groundtruth_boxes: box_label fields.InputDataFields.groundtruth_boxes: box_label
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment