"vscode:/vscode.git/clone" did not exist on "0dd5cc753bdc23f3fc1c9ca82170efde4d0c68aa"
Commit 8ea058b9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 313693539
parent 980b27d5
......@@ -14,9 +14,10 @@
# ==============================================================================
"""Factory to provide model configs."""
from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import maskrcnn_config
from official.vision.detection.configs import retinanet_config
from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import shapemask_config
def config_generator(model):
......@@ -27,6 +28,9 @@ def config_generator(model):
elif model == 'mask_rcnn':
default_config = maskrcnn_config.MASKRCNN_CFG
restrictions = maskrcnn_config.MASKRCNN_RESTRICTIONS
elif model == 'shapemask':
default_config = shapemask_config.SHAPEMASK_CFG
restrictions = shapemask_config.SHAPEMASK_RESTRICTIONS
else:
raise ValueError('Model %s is not supported.' % model)
......
......@@ -22,6 +22,7 @@ from official.vision.detection.dataloader import maskrcnn_parser
from official.vision.detection.dataloader import retinanet_parser
from official.vision.detection.dataloader import shapemask_parser
def parser_generator(params, mode):
"""Generator function for various dataset parser."""
if params.architecture.parser == 'retinanet_parser':
......
......@@ -419,6 +419,7 @@ class Parser(object):
inputs = {
'image': image,
'image_info': image_info,
'mask_boxes': sampled_boxes,
'mask_outer_boxes': mask_outer_boxes,
'mask_classes': sampled_classes,
......
......@@ -54,7 +54,7 @@ flags.DEFINE_string(
flags.DEFINE_string(
'model', default='retinanet',
help='Model to run: `retinanet` or `mask_rcnn`.')
help='Model to run: `retinanet`, `mask_rcnn` or `shapemask`.')
flags.DEFINE_string('training_file_pattern', None,
'Location of the train data.')
......@@ -75,7 +75,7 @@ def run_executor(params,
eval_input_fn=None,
callbacks=None,
prebuilt_strategy=None):
"""Runs Retinanet model on distribution strategy defined by the user."""
"""Runs the object detection model on distribution strategy defined by the user."""
if params.architecture.use_bfloat16:
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
......@@ -203,7 +203,7 @@ def run(callbacks=None):
params.lock()
pp = pprint.PrettyPrinter()
params_str = pp.pformat(params.as_dict())
logging.info('Model Parameters: {}'.format(params_str))
logging.info('Model Parameters: %s', params_str)
train_input_fn = None
eval_input_fn = None
......
......@@ -17,6 +17,7 @@
from official.vision.detection.modeling import maskrcnn_model
from official.vision.detection.modeling import retinanet_model
from official.vision.detection.modeling import shapemask_model
def model_generator(params):
......@@ -25,6 +26,8 @@ def model_generator(params):
model_fn = retinanet_model.RetinanetModel(params)
elif params.type == 'mask_rcnn':
model_fn = maskrcnn_model.MaskrcnnModel(params)
elif params.type == 'shapemask':
model_fn = shapemask_model.ShapeMaskModel(params)
else:
raise ValueError('Model %s is not supported.'% params.type)
......
......@@ -411,8 +411,10 @@ class RetinanetClassLoss(object):
bs, height, width, _, _ = cls_targets_one_hot.get_shape().as_list()
cls_targets_one_hot = tf.reshape(cls_targets_one_hot,
[bs, height, width, -1])
loss = focal_loss(cls_outputs, cls_targets_one_hot,
self._focal_loss_alpha, self._focal_loss_gamma,
loss = focal_loss(tf.cast(cls_outputs, dtype=tf.float32),
tf.cast(cls_targets_one_hot, dtype=tf.float32),
self._focal_loss_alpha,
self._focal_loss_gamma,
num_positives)
ignore_loss = tf.where(
......
......@@ -288,6 +288,7 @@ def _generate_detections_batched(boxes,
pad_per_class=False,)
# De-normalizes box cooridinates.
nmsed_boxes *= normalizer
nmsed_classes = tf.cast(nmsed_classes, tf.int32)
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment