"docs/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "a697002cfbb4e7d2fb8a1a4646d958bc57a2d973"
Commit 8d9a16ce authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 285844156
parent 913640d4
...@@ -30,10 +30,11 @@ REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$' ...@@ -30,10 +30,11 @@ REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
BASE_CFG = { BASE_CFG = {
'model_dir': '', 'model_dir': '',
'use_tpu': True, 'use_tpu': True,
'strategy_type': 'tpu',
'isolate_session_state': False, 'isolate_session_state': False,
'train': { 'train': {
'iterations_per_loop': 100, 'iterations_per_loop': 100,
'train_batch_size': 64, 'batch_size': 64,
'total_steps': 22500, 'total_steps': 22500,
'num_cores_per_replica': None, 'num_cores_per_replica': None,
'input_partition_dims': None, 'input_partition_dims': None,
...@@ -57,13 +58,13 @@ BASE_CFG = { ...@@ -57,13 +58,13 @@ BASE_CFG = {
'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX, 'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
'train_file_pattern': '', 'train_file_pattern': '',
'train_dataset_type': 'tfrecord', 'train_dataset_type': 'tfrecord',
'transpose_input': True, 'transpose_input': False,
'regularization_variable_regex': REGULARIZATION_VAR_REGEX, 'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
'l2_weight_decay': 0.0001, 'l2_weight_decay': 0.0001,
'gradient_clip_norm': 0.0, 'gradient_clip_norm': 0.0,
}, },
'eval': { 'eval': {
'eval_batch_size': 8, 'batch_size': 8,
'eval_samples': 5000, 'eval_samples': 5000,
'min_eval_interval': 180, 'min_eval_interval': 180,
'eval_timeout': None, 'eval_timeout': None,
......
...@@ -34,6 +34,7 @@ MASKRCNN_CFG.override({ ...@@ -34,6 +34,7 @@ MASKRCNN_CFG.override({
'maskrcnn_parser': { 'maskrcnn_parser': {
'use_bfloat16': True, 'use_bfloat16': True,
'output_size': [1024, 1024], 'output_size': [1024, 1024],
'num_channels': 3,
'rpn_match_threshold': 0.7, 'rpn_match_threshold': 0.7,
'rpn_unmatched_threshold': 0.3, 'rpn_unmatched_threshold': 0.3,
'rpn_batch_size_per_im': 256, 'rpn_batch_size_per_im': 256,
......
...@@ -275,6 +275,10 @@ class Parser(object): ...@@ -275,6 +275,10 @@ class Parser(object):
if self._use_bfloat16: if self._use_bfloat16:
image = tf.cast(image, dtype=tf.bfloat16) image = tf.cast(image, dtype=tf.bfloat16)
inputs = {
'image': image,
'image_info': image_info,
}
# Packs labels for model_fn outputs. # Packs labels for model_fn outputs.
labels = { labels = {
'anchor_boxes': input_anchor.multilevel_boxes, 'anchor_boxes': input_anchor.multilevel_boxes,
...@@ -282,15 +286,16 @@ class Parser(object): ...@@ -282,15 +286,16 @@ class Parser(object):
'rpn_score_targets': rpn_score_targets, 'rpn_score_targets': rpn_score_targets,
'rpn_box_targets': rpn_box_targets, 'rpn_box_targets': rpn_box_targets,
} }
labels['gt_boxes'] = input_utils.pad_to_fixed_size( inputs['gt_boxes'] = input_utils.pad_to_fixed_size(boxes,
boxes, self._max_num_instances, -1) self._max_num_instances,
labels['gt_classes'] = input_utils.pad_to_fixed_size( -1)
inputs['gt_classes'] = input_utils.pad_to_fixed_size(
classes, self._max_num_instances, -1) classes, self._max_num_instances, -1)
if self._include_mask: if self._include_mask:
labels['gt_masks'] = input_utils.pad_to_fixed_size( inputs['gt_masks'] = input_utils.pad_to_fixed_size(
masks, self._max_num_instances, -1) masks, self._max_num_instances, -1)
return image, labels return inputs, labels
def _parse_eval_data(self, data): def _parse_eval_data(self, data):
"""Parses data for evaluation.""" """Parses data for evaluation."""
...@@ -348,11 +353,7 @@ class Parser(object): ...@@ -348,11 +353,7 @@ class Parser(object):
self._anchor_size, self._anchor_size,
(image_height, image_width)) (image_height, image_width))
labels = { labels = {}
'source_id': dataloader_utils.process_source_id(data['source_id']),
'anchor_boxes': input_anchor.multilevel_boxes,
'image_info': image_info,
}
if self._mode == ModeKeys.PREDICT_WITH_GT: if self._mode == ModeKeys.PREDICT_WITH_GT:
# Converts boxes from normalized coordinates to pixel coordinates. # Converts boxes from normalized coordinates to pixel coordinates.
...@@ -372,6 +373,11 @@ class Parser(object): ...@@ -372,6 +373,11 @@ class Parser(object):
groundtruths['source_id']) groundtruths['source_id'])
groundtruths = dataloader_utils.pad_groundtruths_to_fixed_size( groundtruths = dataloader_utils.pad_groundtruths_to_fixed_size(
groundtruths, self._max_num_instances) groundtruths, self._max_num_instances)
# TODO(yeqing): Remove the `groundtrtuh` layer key (no longer needed).
labels['groundtruths'] = groundtruths labels['groundtruths'] = groundtruths
inputs = {
'image': image,
'image_info': image_info,
}
return image, labels return inputs, labels
...@@ -99,6 +99,7 @@ class Model(object): ...@@ -99,6 +99,7 @@ class Model(object):
params.train.learning_rate) params.train.learning_rate)
self._frozen_variable_prefix = params.train.frozen_variable_prefix self._frozen_variable_prefix = params.train.frozen_variable_prefix
self._l2_weight_decay = params.train.l2_weight_decay
# Checkpoint restoration. # Checkpoint restoration.
self._checkpoint = params.train.checkpoint.as_dict() self._checkpoint = params.train.checkpoint.as_dict()
......
...@@ -147,6 +147,7 @@ class RpnBoxLoss(object): ...@@ -147,6 +147,7 @@ class RpnBoxLoss(object):
"""Region Proposal Network box regression loss function.""" """Region Proposal Network box regression loss function."""
def __init__(self, params): def __init__(self, params):
self._delta = params.huber_loss_delta
self._huber_loss = tf.keras.losses.Huber( self._huber_loss = tf.keras.losses.Huber(
delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM) delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
...@@ -212,7 +213,7 @@ class FastrcnnClassLoss(object): ...@@ -212,7 +213,7 @@ class FastrcnnClassLoss(object):
a scalar tensor representing total class loss. a scalar tensor representing total class loss.
""" """
with tf.name_scope('fast_rcnn_loss'): with tf.name_scope('fast_rcnn_loss'):
_, _, _, num_classes = class_outputs.get_shape().as_list() _, _, num_classes = class_outputs.get_shape().as_list()
class_targets = tf.cast(class_targets, dtype=tf.int32) class_targets = tf.cast(class_targets, dtype=tf.int32)
class_targets_one_hot = tf.one_hot(class_targets, num_classes) class_targets_one_hot = tf.one_hot(class_targets, num_classes)
return self._fast_rcnn_class_loss(class_outputs, class_targets_one_hot) return self._fast_rcnn_class_loss(class_outputs, class_targets_one_hot)
...@@ -320,9 +321,6 @@ class FastrcnnBoxLoss(object): ...@@ -320,9 +321,6 @@ class FastrcnnBoxLoss(object):
class MaskrcnnLoss(object): class MaskrcnnLoss(object):
"""Mask R-CNN instance segmentation mask loss function.""" """Mask R-CNN instance segmentation mask loss function."""
def __init__(self):
raise ValueError('Not TF 2.0 ready.')
def __call__(self, mask_outputs, mask_targets, select_class_targets): def __call__(self, mask_outputs, mask_targets, select_class_targets):
"""Computes the mask loss of Mask-RCNN. """Computes the mask loss of Mask-RCNN.
......
...@@ -56,7 +56,6 @@ class RetinanetModel(base_model.Model): ...@@ -56,7 +56,6 @@ class RetinanetModel(base_model.Model):
self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator( self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
params.postprocess) params.postprocess)
self._l2_weight_decay = params.train.l2_weight_decay
self._transpose_input = params.train.transpose_input self._transpose_input = params.train.transpose_input
assert not self._transpose_input, 'Transpose input is not supportted.' assert not self._transpose_input, 'Transpose input is not supportted.'
# Input layer. # Input layer.
...@@ -134,6 +133,7 @@ class RetinanetModel(base_model.Model): ...@@ -134,6 +133,7 @@ class RetinanetModel(base_model.Model):
return self._keras_model return self._keras_model
def post_processing(self, labels, outputs): def post_processing(self, labels, outputs):
# TODO(yeqing): Moves the output related part into build_outputs.
required_output_fields = ['cls_outputs', 'box_outputs'] required_output_fields = ['cls_outputs', 'box_outputs']
for field in required_output_fields: for field in required_output_fields:
if field not in outputs: if field not in outputs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment