Commit 8ea058b9 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 313693539
parent 980b27d5
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
# ============================================================================== # ==============================================================================
"""Factory to provide model configs.""" """Factory to provide model configs."""
from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import maskrcnn_config from official.vision.detection.configs import maskrcnn_config
from official.vision.detection.configs import retinanet_config from official.vision.detection.configs import retinanet_config
from official.modeling.hyperparams import params_dict from official.vision.detection.configs import shapemask_config
def config_generator(model): def config_generator(model):
...@@ -27,6 +28,9 @@ def config_generator(model): ...@@ -27,6 +28,9 @@ def config_generator(model):
elif model == 'mask_rcnn': elif model == 'mask_rcnn':
default_config = maskrcnn_config.MASKRCNN_CFG default_config = maskrcnn_config.MASKRCNN_CFG
restrictions = maskrcnn_config.MASKRCNN_RESTRICTIONS restrictions = maskrcnn_config.MASKRCNN_RESTRICTIONS
elif model == 'shapemask':
default_config = shapemask_config.SHAPEMASK_CFG
restrictions = shapemask_config.SHAPEMASK_RESTRICTIONS
else: else:
raise ValueError('Model %s is not supported.' % model) raise ValueError('Model %s is not supported.' % model)
......
...@@ -22,6 +22,7 @@ from official.vision.detection.dataloader import maskrcnn_parser ...@@ -22,6 +22,7 @@ from official.vision.detection.dataloader import maskrcnn_parser
from official.vision.detection.dataloader import retinanet_parser from official.vision.detection.dataloader import retinanet_parser
from official.vision.detection.dataloader import shapemask_parser from official.vision.detection.dataloader import shapemask_parser
def parser_generator(params, mode): def parser_generator(params, mode):
"""Generator function for various dataset parser.""" """Generator function for various dataset parser."""
if params.architecture.parser == 'retinanet_parser': if params.architecture.parser == 'retinanet_parser':
......
...@@ -419,6 +419,7 @@ class Parser(object): ...@@ -419,6 +419,7 @@ class Parser(object):
inputs = { inputs = {
'image': image, 'image': image,
'image_info': image_info,
'mask_boxes': sampled_boxes, 'mask_boxes': sampled_boxes,
'mask_outer_boxes': mask_outer_boxes, 'mask_outer_boxes': mask_outer_boxes,
'mask_classes': sampled_classes, 'mask_classes': sampled_classes,
......
...@@ -54,7 +54,7 @@ flags.DEFINE_string( ...@@ -54,7 +54,7 @@ flags.DEFINE_string(
flags.DEFINE_string( flags.DEFINE_string(
'model', default='retinanet', 'model', default='retinanet',
help='Model to run: `retinanet` or `mask_rcnn`.') help='Model to run: `retinanet`, `mask_rcnn` or `shapemask`.')
flags.DEFINE_string('training_file_pattern', None, flags.DEFINE_string('training_file_pattern', None,
'Location of the train data.') 'Location of the train data.')
...@@ -75,7 +75,7 @@ def run_executor(params, ...@@ -75,7 +75,7 @@ def run_executor(params,
eval_input_fn=None, eval_input_fn=None,
callbacks=None, callbacks=None,
prebuilt_strategy=None): prebuilt_strategy=None):
"""Runs Retinanet model on distribution strategy defined by the user.""" """Runs the object detection model on distribution strategy defined by the user."""
if params.architecture.use_bfloat16: if params.architecture.use_bfloat16:
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy( policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
...@@ -203,7 +203,7 @@ def run(callbacks=None): ...@@ -203,7 +203,7 @@ def run(callbacks=None):
params.lock() params.lock()
pp = pprint.PrettyPrinter() pp = pprint.PrettyPrinter()
params_str = pp.pformat(params.as_dict()) params_str = pp.pformat(params.as_dict())
logging.info('Model Parameters: {}'.format(params_str)) logging.info('Model Parameters: %s', params_str)
train_input_fn = None train_input_fn = None
eval_input_fn = None eval_input_fn = None
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
from official.vision.detection.modeling import maskrcnn_model from official.vision.detection.modeling import maskrcnn_model
from official.vision.detection.modeling import retinanet_model from official.vision.detection.modeling import retinanet_model
from official.vision.detection.modeling import shapemask_model
def model_generator(params): def model_generator(params):
...@@ -25,6 +26,8 @@ def model_generator(params): ...@@ -25,6 +26,8 @@ def model_generator(params):
model_fn = retinanet_model.RetinanetModel(params) model_fn = retinanet_model.RetinanetModel(params)
elif params.type == 'mask_rcnn': elif params.type == 'mask_rcnn':
model_fn = maskrcnn_model.MaskrcnnModel(params) model_fn = maskrcnn_model.MaskrcnnModel(params)
elif params.type == 'shapemask':
model_fn = shapemask_model.ShapeMaskModel(params)
else: else:
raise ValueError('Model %s is not supported.'% params.type) raise ValueError('Model %s is not supported.'% params.type)
......
...@@ -411,8 +411,10 @@ class RetinanetClassLoss(object): ...@@ -411,8 +411,10 @@ class RetinanetClassLoss(object):
bs, height, width, _, _ = cls_targets_one_hot.get_shape().as_list() bs, height, width, _, _ = cls_targets_one_hot.get_shape().as_list()
cls_targets_one_hot = tf.reshape(cls_targets_one_hot, cls_targets_one_hot = tf.reshape(cls_targets_one_hot,
[bs, height, width, -1]) [bs, height, width, -1])
loss = focal_loss(cls_outputs, cls_targets_one_hot, loss = focal_loss(tf.cast(cls_outputs, dtype=tf.float32),
self._focal_loss_alpha, self._focal_loss_gamma, tf.cast(cls_targets_one_hot, dtype=tf.float32),
self._focal_loss_alpha,
self._focal_loss_gamma,
num_positives) num_positives)
ignore_loss = tf.where( ignore_loss = tf.where(
......
...@@ -288,6 +288,7 @@ def _generate_detections_batched(boxes, ...@@ -288,6 +288,7 @@ def _generate_detections_batched(boxes,
pad_per_class=False,) pad_per_class=False,)
# De-normalizes box coordinates. # De-normalizes box coordinates.
nmsed_boxes *= normalizer nmsed_boxes *= normalizer
nmsed_classes = tf.cast(nmsed_classes, tf.int32)
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment