"...git@developer.sourcefind.cn:modelzoo/bert_migraphx.git" did not exist on "1d125612cc711d5e3939c9155b1ae34e82f6b2b1"
Commit cc08dc87 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Update to trainer to allow for reading multiclass scores

PiperOrigin-RevId: 192624207
parent 1bddd18e
...@@ -69,10 +69,13 @@ def create_input_queue(batch_size_per_clone, create_tensor_dict_fn, ...@@ -69,10 +69,13 @@ def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
in tensor_dict) in tensor_dict)
include_keypoints = (fields.InputDataFields.groundtruth_keypoints include_keypoints = (fields.InputDataFields.groundtruth_keypoints
in tensor_dict) in tensor_dict)
include_multiclass_scores = (fields.InputDataFields.multiclass_scores
in tensor_dict)
if data_augmentation_options: if data_augmentation_options:
tensor_dict = preprocessor.preprocess( tensor_dict = preprocessor.preprocess(
tensor_dict, data_augmentation_options, tensor_dict, data_augmentation_options,
func_arg_map=preprocessor.get_default_func_arg_map( func_arg_map=preprocessor.get_default_func_arg_map(
include_multiclass_scores=include_multiclass_scores,
include_instance_masks=include_instance_masks, include_instance_masks=include_instance_masks,
include_keypoints=include_keypoints)) include_keypoints=include_keypoints))
...@@ -85,7 +88,10 @@ def create_input_queue(batch_size_per_clone, create_tensor_dict_fn, ...@@ -85,7 +88,10 @@ def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
return input_queue return input_queue
def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False): def get_inputs(input_queue,
num_classes,
merge_multiple_label_boxes=False,
use_multiclass_scores=False):
"""Dequeues batch and constructs inputs to object detection model. """Dequeues batch and constructs inputs to object detection model.
Args: Args:
...@@ -95,6 +101,8 @@ def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False): ...@@ -95,6 +101,8 @@ def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
or not. Defaults to false. Merged boxes are represented with a single or not. Defaults to false. Merged boxes are represented with a single
box and a k-hot encoding of the multiple labels associated with the box and a k-hot encoding of the multiple labels associated with the
boxes. boxes.
use_multiclass_scores: Whether to use multiclass scores instead of
groundtruth_classes.
Returns: Returns:
images: a list of 3-D float tensor of images. images: a list of 3-D float tensor of images.
...@@ -123,9 +131,19 @@ def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False): ...@@ -123,9 +131,19 @@ def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
classes_gt = tf.cast(read_data[fields.InputDataFields.groundtruth_classes], classes_gt = tf.cast(read_data[fields.InputDataFields.groundtruth_classes],
tf.int32) tf.int32)
classes_gt -= label_id_offset classes_gt -= label_id_offset
if merge_multiple_label_boxes and use_multiclass_scores:
raise ValueError(
'Using both merge_multiple_label_boxes and use_multiclass_scores is'
'not supported'
)
if merge_multiple_label_boxes: if merge_multiple_label_boxes:
location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels( location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
location_gt, classes_gt, num_classes) location_gt, classes_gt, num_classes)
elif use_multiclass_scores:
classes_gt = tf.cast(read_data[fields.InputDataFields.multiclass_scores],
tf.float32)
else: else:
classes_gt = util_ops.padded_one_hot_encoding( classes_gt = util_ops.padded_one_hot_encoding(
indices=classes_gt, depth=num_classes, left_pad=0) indices=classes_gt, depth=num_classes, left_pad=0)
...@@ -155,7 +173,8 @@ def _create_losses(input_queue, create_model_fn, train_config): ...@@ -155,7 +173,8 @@ def _create_losses(input_queue, create_model_fn, train_config):
groundtruth_masks_list, groundtruth_keypoints_list, _) = get_inputs( groundtruth_masks_list, groundtruth_keypoints_list, _) = get_inputs(
input_queue, input_queue,
detection_model.num_classes, detection_model.num_classes,
train_config.merge_multiple_label_boxes) train_config.merge_multiple_label_boxes,
train_config.use_multiclass_scores)
preprocessed_images = [] preprocessed_images = []
true_image_shapes = [] true_image_shapes = []
...@@ -183,9 +202,19 @@ def _create_losses(input_queue, create_model_fn, train_config): ...@@ -183,9 +202,19 @@ def _create_losses(input_queue, create_model_fn, train_config):
tf.losses.add_loss(loss_tensor) tf.losses.add_loss(loss_tensor)
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task, def train(create_tensor_dict_fn,
num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name, create_model_fn,
is_chief, train_dir, graph_hook_fn=None): train_config,
master,
task,
num_clones,
worker_replicas,
clone_on_cpu,
ps_tasks,
worker_job_name,
is_chief,
train_dir,
graph_hook_fn=None):
"""Training function for detection models. """Training function for detection models.
Args: Args:
......
...@@ -37,12 +37,15 @@ def get_input_function(): ...@@ -37,12 +37,15 @@ def get_input_function():
[1], minval=0, maxval=NUMBER_OF_CLASSES, dtype=tf.int32) [1], minval=0, maxval=NUMBER_OF_CLASSES, dtype=tf.int32)
box_label = tf.random_uniform( box_label = tf.random_uniform(
[1, 4], minval=0.4, maxval=0.6, dtype=tf.float32) [1, 4], minval=0.4, maxval=0.6, dtype=tf.float32)
multiclass_scores = tf.random_uniform(
[1, NUMBER_OF_CLASSES], minval=0.4, maxval=0.6, dtype=tf.float32)
return { return {
fields.InputDataFields.image: image, fields.InputDataFields.image: image,
fields.InputDataFields.key: key, fields.InputDataFields.key: key,
fields.InputDataFields.groundtruth_classes: class_label, fields.InputDataFields.groundtruth_classes: class_label,
fields.InputDataFields.groundtruth_boxes: box_label fields.InputDataFields.groundtruth_boxes: box_label,
fields.InputDataFields.multiclass_scores: multiclass_scores
} }
...@@ -203,6 +206,50 @@ class TrainerTest(tf.test.TestCase): ...@@ -203,6 +206,50 @@ class TrainerTest(tf.test.TestCase):
train_dir = self.get_temp_dir() train_dir = self.get_temp_dir()
trainer.train(
create_tensor_dict_fn=get_input_function,
create_model_fn=FakeDetectionModel,
train_config=train_config,
master='',
task=0,
num_clones=1,
worker_replicas=1,
clone_on_cpu=True,
ps_tasks=0,
worker_job_name='worker',
is_chief=True,
train_dir=train_dir)
def test_configure_trainer_with_multiclass_scores_and_train_two_steps(self):
train_config_text_proto = """
optimizer {
adam_optimizer {
learning_rate {
constant_learning_rate {
learning_rate: 0.01
}
}
}
}
data_augmentation_options {
random_adjust_brightness {
max_delta: 0.2
}
}
data_augmentation_options {
random_adjust_contrast {
min_delta: 0.7
max_delta: 1.1
}
}
num_steps: 2
use_multiclass_scores: true
"""
train_config = train_pb2.TrainConfig()
text_format.Merge(train_config_text_proto, train_config)
train_dir = self.get_temp_dir()
trainer.train(create_tensor_dict_fn=get_input_function, trainer.train(create_tensor_dict_fn=get_input_function,
create_model_fn=FakeDetectionModel, create_model_fn=FakeDetectionModel,
train_config=train_config, train_config=train_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment