"include/vscode:/vscode.git/clone" did not exist on "60ecfd73f9d006e02ef2a33820115222f971bd98"
Unverified Commit abfd0698 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

compute small instance weights based on mask area

parent 95f061ba
......@@ -45,6 +45,8 @@ class Parser(hyperparams.Config):
aug_scale_max: float = 1.0
aug_rand_hflip: bool = True
sigma: float = 8.0
small_instance_area_threshold: int = 4096
small_instance_weight: float = 3.0
dtype = 'float32'
@dataclasses.dataclass
......@@ -92,8 +94,8 @@ class PanopticDeeplabPostProcessor(hyperparams.Config):
label_divisor: int = 256 * 256 * 256
stuff_area_limit: int = 4096
ignore_label: int = 0
nms_kernel: int = 41
keep_k_centers: int = 400
nms_kernel: int = 7
keep_k_centers: int = 200
rescale_predictions: bool = True
@dataclasses.dataclass
......@@ -119,7 +121,6 @@ class Losses(hyperparams.Config):
ignore_label: int = 0
class_weights: List[float] = dataclasses.field(default_factory=list)
l2_weight_decay: float = 1e-4
use_groundtruth_dimension: bool = True
top_k_percent_pixels: float = 0.15
segmentation_loss_weight: float = 1.0
center_heatmap_loss_weight: float = 200
......@@ -235,7 +236,6 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
label_smoothing=0.0,
ignore_label=ignore_label,
l2_weight_decay=0.0,
use_groundtruth_dimension=True,
top_k_percent_pixels=0.2,
segmentation_loss_weight=1.0,
center_heatmap_loss_weight=200,
......@@ -248,7 +248,9 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
aug_scale_min=0.5,
aug_scale_max=1.5,
aug_rand_hflip=True,
sigma=8.0)),
sigma=8.0,
small_instance_area_threshold=4096,
small_instance_weight=3.0)),
validation_data=DataConfig(
input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
......@@ -259,15 +261,17 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
aug_scale_min=1.0,
aug_scale_max=1.0,
aug_rand_hflip=False,
sigma=8.0),
sigma=8.0,
small_instance_area_threshold=4096,
small_instance_weight=3.0),
drop_remainder=False),
evaluation=Evaluation(
ignored_label=ignore_label,
max_instances_per_category=256,
offset=256 * 256 * 256,
offset=256*256*256,
is_thing=is_thing,
rescale_predictions=True,
report_per_class_pq=True,
report_per_class_pq=False,
report_per_class_iou=False,
report_train_mean_iou=False)),
trainer=cfg.TrainerConfig(
......
......@@ -83,6 +83,8 @@ class Parser(parser.Parser):
aug_scale_min=1.0,
aug_scale_max=1.0,
sigma=8.0,
small_instance_area_threshold=4096,
small_instance_weight=3.0,
dtype='float32'):
"""Initializes parameters for parsing annotations in the dataset.
......@@ -104,6 +106,8 @@ class Parser(parser.Parser):
data augmentation during training.
sigma: `float`, standard deviation for generating 2D Gaussian to encode
centers.
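small_instance_area_threshold: `int`, instances whose mask area (number of
pixels in the instance mask) is below this value are treated as small
instances.
small_instance_weight: `float`, value written into `semantic_weights` at
the pixels of small instances; all other pixels keep a weight of 1.0.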
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
"""
self._output_size = output_size
......@@ -127,6 +131,9 @@ class Parser(parser.Parser):
self._gaussian, self._gaussian_size = _compute_gaussian_from_std(
self._sigma)
self._gaussian = tf.reshape(self._gaussian, shape=[-1])
self._small_instance_area_threshold = small_instance_area_threshold
self._small_instance_weight = small_instance_weight
def _resize_and_crop_mask(self, mask, image_info, is_training):
"""Resizes and crops mask using `image_info` dict"""
......@@ -194,8 +201,10 @@ class Parser(parser.Parser):
image_info,
is_training=is_training)
instance_centers_heatmap, instance_centers_offset = self._encode_centers_and_offets(
instance_mask=instance_mask[:, :, 0])
(instance_centers_heatmap,
instance_centers_offset,
semantic_weights) = self._encode_centers_and_offets(
instance_mask=instance_mask[:, :, 0])
# Cast image and labels as self._dtype
image = tf.cast(image, dtype=self._dtype)
......@@ -216,6 +225,7 @@ class Parser(parser.Parser):
'instance_mask': instance_mask,
'instance_centers_heatmap': instance_centers_heatmap,
'instance_centers_offset': instance_centers_offset,
'semantic_weights': semantic_weights,
'valid_mask': valid_mask,
'things_mask': things_mask,
'image_info': image_info
......@@ -259,6 +269,9 @@ class Parser(parser.Parser):
centers_offset_x = tf.zeros(
shape=[height, width],
dtype=tf.float32)
semantic_weights = tf.ones(
shape=[height, width],
dtype=tf.float32)
unique_instance_ids, _ = tf.unique(tf.reshape(instance_mask, [-1]))
......@@ -270,11 +283,18 @@ class Parser(parser.Parser):
continue
mask = tf.equal(instance_mask, instance_id)
mask_area = tf.reduce_sum(tf.cast(mask, dtype=tf.float32))
mask_indices = tf.cast(tf.where(mask), dtype=tf.float32)
mask_center = tf.reduce_mean(mask_indices, axis=0)
mask_center_y = tf.cast(tf.round(mask_center[0]), dtype=tf.int32)
mask_center_x = tf.cast(tf.round(mask_center[1]), dtype=tf.int32)
if mask_area < self._small_instance_area_threshold:
semantic_weights = tf.where(
mask,
self._small_instance_weight,
semantic_weights)
gaussian_size = self._gaussian_size
indices_y = tf.range(mask_center_y, mask_center_y + gaussian_size)
indices_x = tf.range(mask_center_x, mask_center_x + gaussian_size)
......@@ -308,4 +328,6 @@ class Parser(parser.Parser):
[centers_offset_y, centers_offset_x],
axis=-1)
return instance_centers_heatmap, instance_centers_offset
return (instance_centers_heatmap,
instance_centers_offset,
semantic_weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment