Unverified Commit 234011f8 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge pull request #5 from srihari-humbarwadi/training

 * set segmentation loss weight to 0.5
 * skip boxes with zero area while pasting masks
parents d1f01835 b14f82a3
...@@ -107,7 +107,7 @@ class Losses(maskrcnn.Losses): ...@@ -107,7 +107,7 @@ class Losses(maskrcnn.Losses):
semantic_segmentation_use_groundtruth_dimension: bool = True semantic_segmentation_use_groundtruth_dimension: bool = True
semantic_segmentation_top_k_percent_pixels: float = 1.0 semantic_segmentation_top_k_percent_pixels: float = 1.0
instance_segmentation_weight: float = 1.0 instance_segmentation_weight: float = 1.0
semantic_segmentation_weight: float = 1.0 semantic_segmentation_weight: float = 0.5
@dataclasses.dataclass @dataclasses.dataclass
...@@ -170,7 +170,8 @@ def panoptic_fpn_coco() -> cfg.ExperimentConfig: ...@@ -170,7 +170,8 @@ def panoptic_fpn_coco() -> cfg.ExperimentConfig:
is_thing.append(True if idx <= num_thing_categories else False) is_thing.append(True if idx <= num_thing_categories else False)
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'), runtime=cfg.RuntimeConfig(
mixed_precision_dtype='bfloat16', enable_xla=True),
task=PanopticMaskRCNNTask( task=PanopticMaskRCNNTask(
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080', # pylint: disable=line-too-long init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080', # pylint: disable=line-too-long
init_checkpoint_modules=['backbone'], init_checkpoint_modules=['backbone'],
......
...@@ -79,26 +79,27 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -79,26 +79,27 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
pasted_mask = tf.ones( pasted_mask = tf.ones(
self._output_size + [1], dtype=mask.dtype) * self._void_class_label self._output_size + [1], dtype=mask.dtype) * self._void_class_label
ymin = box[0] ymin = tf.clip_by_value(box[0], 0, self._output_size[0])
xmin = box[1] xmin = tf.clip_by_value(box[1], 0, self._output_size[1])
ymax = tf.clip_by_value(box[2] + 1, 0, self._output_size[0]) ymax = tf.clip_by_value(box[2] + 1, 0, self._output_size[0])
xmax = tf.clip_by_value(box[3] + 1, 0, self._output_size[1]) xmax = tf.clip_by_value(box[3] + 1, 0, self._output_size[1])
box_height = ymax - ymin box_height = ymax - ymin
box_width = xmax - xmin box_width = xmax - xmin
# resize mask to match the shape of the instance bounding box if not (box_height == 0 or box_width == 0):
resized_mask = tf.image.resize( # resize mask to match the shape of the instance bounding box
mask, resized_mask = tf.image.resize(
size=(box_height, box_width), mask,
method='nearest') size=(box_height, box_width),
method='nearest')
# paste resized mask on a blank mask that matches image shape
pasted_mask = tf.raw_ops.TensorStridedSliceUpdate( # paste resized mask on a blank mask that matches image shape
input=pasted_mask, pasted_mask = tf.raw_ops.TensorStridedSliceUpdate(
begin=[ymin, xmin], input=pasted_mask,
end=[ymax, xmax], begin=[ymin, xmin],
strides=[1, 1], end=[ymax, xmax],
value=resized_mask) strides=[1, 1],
value=resized_mask)
return pasted_mask return pasted_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment