Unverified Commit ca552843 authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'panoptic-segmentation' into panoptic-segmentation

parents 7e2f7a35 6b90e134
......@@ -30,8 +30,10 @@ class ResNet(hyperparams.Config):
stem_type: str = 'v0'
se_ratio: float = 0.0
stochastic_depth_drop_rate: float = 0.0
scale_stem: bool = True
resnetd_shortcut: bool = False
replace_stem_max_pool: bool = False
bn_trainable: bool = True
@dataclasses.dataclass
......
......@@ -15,15 +15,44 @@
# Lint as: python3
"""Common configurations."""
import dataclasses
from typing import Optional
# Import libraries
import dataclasses
# Import libraries
from official.core import config_definitions as cfg
from official.modeling import hyperparams
@dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config):
"""A simple TF Example decoder config."""
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
@dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config):
"""TF Example decoder with label map config."""
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
label_map: str = ''
@dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig):
"""Data decoder config.
Attributes:
type: 'str', type of data decoder to be used, one of the fields below.
simple_decoder: simple TF Example decoder config.
label_map_decoder: TF Example decoder with label map config.
"""
type: Optional[str] = 'simple_decoder'
simple_decoder: TfExampleDecoder = TfExampleDecoder()
label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
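Note: DataDecoder is a hyperparams.OneOfConfig, so the `type` field selects which sub-config is active. A minimal usage sketch (not part of this diff; assumes the `get()` accessor provided by official.modeling.hyperparams oneof configs):

# Select the label-map decoder via the `type` field; the path is a placeholder.
decoder = DataDecoder(type='label_map_decoder')
decoder.label_map_decoder.label_map = '/path/to/label_map.yaml'
active_decoder = decoder.get()  # returns the TfExampleDecoderLabelMap instance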
@dataclasses.dataclass
class RandAugment(hyperparams.Config):
"""Configuration for RandAugment."""
......
......@@ -32,6 +32,7 @@ class Identity(hyperparams.Config):
class FPN(hyperparams.Config):
"""FPN config."""
num_filters: int = 256
fusion_type: str = 'sum'
use_separable_conv: bool = False
......@@ -50,6 +51,7 @@ class ASPP(hyperparams.Config):
dilation_rates: List[int] = dataclasses.field(default_factory=list)
dropout_rate: float = 0.0
num_filters: int = 256
use_depthwise_convolution: bool = False
pool_kernel_size: Optional[List[int]] = None # Use global average pooling.
......
......@@ -14,11 +14,10 @@
# Lint as: python3
"""Image classification configuration definition."""
import dataclasses
import os
from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
......@@ -44,6 +43,7 @@ class DataConfig(cfg.DataConfig):
image_field_key: str = 'image/encoded'
label_field_key: str = 'image/class/label'
decode_jpeg_only: bool = True
decoder: Optional[common.DataDecoder] = common.DataDecoder()
# Keep for backward compatibility.
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'.
......
......@@ -13,11 +13,11 @@
# limitations under the License.
# Lint as: python3
"""Mask R-CNN configuration definition."""
"""R-CNN(-RS) configuration definition."""
import dataclasses
import os
from typing import List, Optional
from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory
......@@ -29,26 +29,6 @@ from official.vision.beta.configs import backbones
# pylint: disable=missing-class-docstring
@dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config):
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
@dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config):
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
label_map: str = ''
@dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig):
type: Optional[str] = 'simple_decoder'
simple_decoder: TfExampleDecoder = TfExampleDecoder()
label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
@dataclasses.dataclass
class Parser(hyperparams.Config):
num_channels: int = 3
......@@ -73,7 +53,7 @@ class DataConfig(cfg.DataConfig):
global_batch_size: int = 0
is_training: bool = False
dtype: str = 'bfloat16'
decoder: DataDecoder = DataDecoder()
decoder: common.DataDecoder = common.DataDecoder()
parser: Parser = Parser()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
......@@ -152,6 +132,7 @@ class DetectionGenerator(hyperparams.Config):
nms_iou_threshold: float = 0.5
max_num_detections: int = 100
use_batched_nms: bool = False
use_cpu_nms: bool = False
@dataclasses.dataclass
......@@ -221,7 +202,8 @@ class MaskRCNNTask(cfg.TaskConfig):
drop_remainder=False)
losses: Losses = Losses()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone, and/or decoder
annotation_file: Optional[str] = None
per_category_metrics: bool = False
# If set, we only use masks for the specified class IDs.
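Note: since init_checkpoint_modules now accepts a list as well as a string, several sub-modules can be restored at once. A hedged sketch using the MaskRCNNTask config defined in this file (checkpoint path is a placeholder):

task = MaskRCNNTask(
    init_checkpoint='/path/to/pretrained/ckpt',
    init_checkpoint_modules=['backbone', 'decoder'])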
......@@ -450,7 +432,7 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
@exp_factory.register_config_factory('cascadercnn_spinenet_coco')
def cascadercnn_spinenet_coco() -> cfg.ExperimentConfig:
"""COCO object detection with Cascade R-CNN with SpineNet backbone."""
"""COCO object detection with Cascade RCNN-RS with SpineNet backbone."""
steps_per_epoch = 463
coco_val_samples = 5000
train_batch_size = 256
......
......@@ -15,9 +15,9 @@
# Lint as: python3
"""RetinaNet configuration definition."""
import os
from typing import List, Optional
import dataclasses
import os
from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory
......@@ -29,22 +29,22 @@ from official.vision.beta.configs import backbones
# pylint: disable=missing-class-docstring
# Keep for backward compatibility.
@dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config):
regenerate_source_id: bool = False
class TfExampleDecoder(common.TfExampleDecoder):
"""A simple TF Example decoder config."""
# Keep for backward compatibility.
@dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config):
regenerate_source_id: bool = False
label_map: str = ''
class TfExampleDecoderLabelMap(common.TfExampleDecoderLabelMap):
"""TF Example decoder with label map config."""
# Keep for backward compatibility.
@dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig):
type: Optional[str] = 'simple_decoder'
simple_decoder: TfExampleDecoder = TfExampleDecoder()
label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
class DataDecoder(common.DataDecoder):
"""Data decoder config."""
@dataclasses.dataclass
......@@ -55,6 +55,7 @@ class Parser(hyperparams.Config):
aug_rand_hflip: bool = False
aug_scale_min: float = 1.0
aug_scale_max: float = 1.0
aug_policy: Optional[str] = None
skip_crowd_during_training: bool = True
max_num_instances: int = 100
......@@ -66,7 +67,7 @@ class DataConfig(cfg.DataConfig):
global_batch_size: int = 0
is_training: bool = False
dtype: str = 'bfloat16'
decoder: DataDecoder = DataDecoder()
decoder: common.DataDecoder = common.DataDecoder()
parser: Parser = Parser()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
......@@ -112,6 +113,7 @@ class DetectionGenerator(hyperparams.Config):
nms_iou_threshold: float = 0.5
max_num_detections: int = 100
use_batched_nms: bool = False
use_cpu_nms: bool = False
@dataclasses.dataclass
......@@ -144,7 +146,8 @@ class RetinaNetTask(cfg.TaskConfig):
validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone, and/or decoder
annotation_file: Optional[str] = None
per_category_metrics: bool = False
export_config: ExportConfig = ExportConfig()
......
......@@ -14,10 +14,10 @@
# Lint as: python3
"""Semantic segmentation configuration definition."""
import dataclasses
import os
from typing import List, Optional, Union
import dataclasses
import numpy as np
from official.core import exp_factory
......@@ -50,8 +50,10 @@ class DataConfig(cfg.DataConfig):
aug_scale_min: float = 1.0
aug_scale_max: float = 1.0
aug_rand_hflip: bool = True
aug_policy: Optional[str] = None
drop_remainder: bool = True
file_type: str = 'tfrecord'
decoder: Optional[common.DataDecoder] = common.DataDecoder()
@dataclasses.dataclass
......@@ -60,6 +62,7 @@ class SegmentationHead(hyperparams.Config):
level: int = 3
num_convs: int = 2
num_filters: int = 256
use_depthwise_convolution: bool = False
prediction_kernel_size: int = 1
upsample_factor: int = 1
feature_fusion: Optional[str] = None # None, deeplabv3plus, or pyramid_fusion
......@@ -119,7 +122,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
def semantic_segmentation() -> cfg.ExperimentConfig:
"""Semantic segmentation general."""
return cfg.ExperimentConfig(
task=SemanticSegmentationModel(),
task=SemanticSegmentationTask(),
trainer=cfg.TrainerConfig(),
restrictions=[
'task.train_data.is_training != None',
......
......@@ -58,6 +58,14 @@ flags.DEFINE_string(
'annotations - boxes and instance masks.')
flags.DEFINE_string('caption_annotations_file', '', 'File containing image '
'captions.')
flags.DEFINE_string('panoptic_annotations_file', '', 'File containing panoptic '
'annotations.')
flags.DEFINE_string('panoptic_masks_dir', '',
'Directory containing panoptic masks annotations.')
flags.DEFINE_boolean(
'include_panoptic_masks', False, 'Whether to include category and '
'instance masks in the result. These are required to run the PQ '
'evaluator. default: False.')
flags.DEFINE_string('output_file_prefix', '/tmp/train', 'Path to output file')
flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')
......@@ -66,6 +74,11 @@ FLAGS = flags.FLAGS
logger = tf.get_logger()
logger.setLevel(logging.INFO)
_VOID_LABEL = 0
_VOID_INSTANCE_ID = 0
_THING_CLASS_ID = 1
_STUFF_CLASSES_OFFSET = 90
def coco_segmentation_to_mask_png(segmentation, height, width, is_crowd):
"""Encode a COCO mask segmentation as PNG string."""
......@@ -74,12 +87,79 @@ def coco_segmentation_to_mask_png(segmentation, height, width, is_crowd):
if not is_crowd:
binary_mask = np.amax(binary_mask, axis=2)
return tfrecord_lib.encode_binary_mask_as_png(binary_mask)
return tfrecord_lib.encode_mask_as_png(binary_mask)
def generate_coco_panoptics_masks(segments_info, mask_path,
include_panoptic_masks,
is_category_thing):
"""Creates masks for panoptic segmentation task.
Args:
segments_info: a list of dicts, where each dict has keys: [u'id',
u'category_id', u'area', u'bbox', u'iscrowd'], detailing information for
each segment in the panoptic mask.
mask_path: path to the panoptic mask.
include_panoptic_masks: bool, when set to True, category and instance
masks are included in the outputs. Set this to True when using
the Panoptic Quality evaluator.
is_category_thing: a dict with category ids as keys and 0/1 as values,
where 0 and 1 represent "stuff" and "thing" classes respectively.
Returns:
A dict with keys: [u'semantic_segmentation_mask', u'category_mask',
u'instance_mask']. The dict contains 'category_mask' and 'instance_mask'
only if `include_panoptic_masks` is set to True.
"""
rgb_mask = tfrecord_lib.read_image(mask_path)
r, g, b = np.split(rgb_mask, 3, axis=-1)
# Decode the RGB-encoded panoptic mask to get segment ids;
# see https://cocodataset.org/#format-data for the encoding.
segments_encoded_mask = (r + g * 256 + b * (256**2)).squeeze()
semantic_segmentation_mask = np.ones_like(
segments_encoded_mask, dtype=np.uint8) * _VOID_LABEL
if include_panoptic_masks:
category_mask = np.ones_like(
segments_encoded_mask, dtype=np.uint8) * _VOID_LABEL
instance_mask = np.ones_like(
segments_encoded_mask, dtype=np.uint8) * _VOID_INSTANCE_ID
for idx, segment in enumerate(segments_info):
segment_id = segment['id']
category_id = segment['category_id']
if is_category_thing[category_id]:
encoded_category_id = _THING_CLASS_ID
instance_id = idx + 1
else:
encoded_category_id = category_id - _STUFF_CLASSES_OFFSET
instance_id = _VOID_INSTANCE_ID
segment_mask = (segments_encoded_mask == segment_id)
semantic_segmentation_mask[segment_mask] = encoded_category_id
if include_panoptic_masks:
category_mask[segment_mask] = category_id
instance_mask[segment_mask] = instance_id
outputs = {
'semantic_segmentation_mask': tfrecord_lib.encode_mask_as_png(
semantic_segmentation_mask)
}
if include_panoptic_masks:
outputs.update({
'category_mask': tfrecord_lib.encode_mask_as_png(category_mask),
'instance_mask': tfrecord_lib.encode_mask_as_png(instance_mask)
})
return outputs
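For reference, a tiny worked example of the COCO panoptic color encoding decoded above; pure numpy and independent of this module:

import numpy as np

# One pixel colored (R, G, B) = (42, 1, 0) encodes segment id
# 42 + 1 * 256 + 0 * 256**2 = 298.
rgb_mask = np.array([[[42, 1, 0]]], dtype=np.uint32)
r, g, b = np.split(rgb_mask, 3, axis=-1)
segment_id = (r + g * 256 + b * (256 ** 2)).squeeze()
assert segment_id == 298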
def coco_annotations_to_lists(bbox_annotations, id_to_name_map,
image_height, image_width, include_masks):
"""Convert COCO annotations to feature lists."""
"""Converts COCO annotations to feature lists."""
data = dict((k, list()) for k in
['xmin', 'xmax', 'ymin', 'ymax', 'is_crowd',
......@@ -160,9 +240,13 @@ def encode_caption_annotations(caption_annotations):
def create_tf_example(image,
image_dirs,
panoptic_masks_dir=None,
bbox_annotations=None,
id_to_name_map=None,
caption_annotations=None,
panoptic_annotation=None,
is_category_thing=None,
include_panoptic_masks=False,
include_masks=False):
"""Converts image and annotations to a tf.Example proto.
......@@ -170,6 +254,7 @@ def create_tf_example(image,
image: dict with keys: [u'license', u'file_name', u'coco_url', u'height',
u'width', u'date_captured', u'flickr_url', u'id']
image_dirs: list of directories containing the image files.
panoptic_masks_dir: `str` of the panoptic masks directory.
bbox_annotations:
list of dicts with keys: [u'segmentation', u'area', u'iscrowd',
u'image_id', u'bbox', u'category_id', u'id'] Notice that bounding box
......@@ -182,6 +267,11 @@ def create_tf_example(image,
id_to_name_map: a dict mapping category IDs to string names.
caption_annotations:
list of dicts with keys: [u'id', u'image_id', u'str'].
panoptic_annotation: dict with keys: [u'image_id', u'file_name',
u'segments_info'], where the value for segments_info is a list of dicts,
with each dict containing information for a single segment in the mask.
is_category_thing: a dict mapping category ids to a bool indicating
whether the category is a "thing" class.
include_panoptic_masks: `bool`, whether to include panoptic masks.
include_masks: Whether to include instance segmentation masks
(PNG encoded) in the result. default: False.
......@@ -234,6 +324,26 @@ def create_tf_example(image,
feature_dict.update(
{'image/caption': tfrecord_lib.convert_to_feature(encoded_captions)})
if panoptic_annotation:
segments_info = panoptic_annotation['segments_info']
panoptic_mask_filename = os.path.join(
panoptic_masks_dir,
panoptic_annotation['file_name'])
encoded_panoptic_masks = generate_coco_panoptics_masks(
segments_info, panoptic_mask_filename, include_panoptic_masks,
is_category_thing)
feature_dict.update(
{'image/segmentation/class/encoded': tfrecord_lib.convert_to_feature(
encoded_panoptic_masks['semantic_segmentation_mask'])})
if include_panoptic_masks:
feature_dict.update({
'image/panoptic/category_mask': tfrecord_lib.convert_to_feature(
encoded_panoptic_masks['category_mask']),
'image/panoptic/instance_mask': tfrecord_lib.convert_to_feature(
encoded_panoptic_masks['instance_mask'])
})
example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
return example, num_annotations_skipped
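A hedged sketch of invoking the updated function with the new panoptic arguments; all paths and lookup dicts below are placeholders:

example, num_skipped = create_tf_example(
    image=image_info,  # one entry from the COCO 'images' list
    image_dirs=['/data/coco/val2017'],
    panoptic_masks_dir='/data/coco/panoptic_val2017',
    panoptic_annotation=img_to_panoptic_annotation[image_info['id']],
    is_category_thing=is_category_thing,
    include_panoptic_masks=True)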
......@@ -287,6 +397,33 @@ def _load_caption_annotations(caption_annotations_file):
return img_to_caption_annotation
def _load_panoptic_annotations(panoptic_annotations_file):
"""Loads panoptic annotation from file."""
with tf.io.gfile.GFile(panoptic_annotations_file, 'r') as fid:
panoptic_annotations = json.load(fid)
img_to_panoptic_annotation = dict()
logging.info('Building panoptic index.')
for annotation in panoptic_annotations['annotations']:
image_id = annotation['image_id']
img_to_panoptic_annotation[image_id] = annotation
is_category_thing = dict()
for category_info in panoptic_annotations['categories']:
is_category_thing[category_info['id']] = category_info['isthing'] == 1
missing_annotation_count = 0
images = panoptic_annotations['images']
for image in images:
image_id = image['id']
if image_id not in img_to_panoptic_annotation:
missing_annotation_count += 1
logging.info(
'%d images are missing panoptic annotations.', missing_annotation_count)
return img_to_panoptic_annotation, is_category_thing
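A sketch of the two lookups this helper returns, assuming the standard COCO panoptic JSON schema (ids below are illustrative only):

# image id -> panoptic annotation dict for that image.
img_to_panoptic_annotation = {
    9: {'image_id': 9,
        'file_name': '000000000009.png',
        'segments_info': [{'id': 3226956, 'category_id': 51,
                           'area': 1240, 'bbox': [1, 2, 30, 40],
                           'iscrowd': 0}]},
}
# category id -> True for "thing" categories, False for "stuff".
is_category_thing = {51: True, 119: False}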
def _load_images_info(images_info_file):
with tf.io.gfile.GFile(images_info_file, 'r') as fid:
info_dict = json.load(fid)
......@@ -294,11 +431,15 @@ def _load_images_info(images_info_file):
def generate_annotations(images, image_dirs,
panoptic_masks_dir=None,
img_to_obj_annotation=None,
img_to_caption_annotation=None, id_to_name_map=None,
img_to_caption_annotation=None,
img_to_panoptic_annotation=None,
is_category_thing=None,
id_to_name_map=None,
include_panoptic_masks=False,
include_masks=False):
"""Generator for COCO annotations."""
for image in images:
object_annotation = (img_to_obj_annotation.get(image['id'], None) if
img_to_obj_annotation else None)
......@@ -306,8 +447,11 @@ def generate_annotations(images, image_dirs,
caption_annotation = (img_to_caption_annotation.get(image['id'], None) if
img_to_caption_annotation else None)
yield (image, image_dirs, object_annotation, id_to_name_map,
caption_annotation, include_masks)
panoptic_annotation = (img_to_panoptic_annotation.get(image['id'], None) if
img_to_panoptic_annotation else None)
yield (image, image_dirs, panoptic_masks_dir, object_annotation,
id_to_name_map, caption_annotation, panoptic_annotation,
is_category_thing, include_panoptic_masks, include_masks)
def _create_tf_record_from_coco_annotations(images_info_file,
......@@ -316,6 +460,9 @@ def _create_tf_record_from_coco_annotations(images_info_file,
num_shards,
object_annotations_file=None,
caption_annotations_file=None,
panoptic_masks_dir=None,
panoptic_annotations_file=None,
include_panoptic_masks=False,
include_masks=False):
"""Loads COCO annotation json files and converts to tf.Record format.
......@@ -331,6 +478,10 @@ def _create_tf_record_from_coco_annotations(images_info_file,
num_shards: Number of output files to create.
object_annotations_file: JSON file containing bounding box annotations.
caption_annotations_file: JSON file containing caption annotations.
panoptic_masks_dir: Directory containing panoptic masks.
panoptic_annotations_file: JSON file containing panoptic annotations.
include_panoptic_masks: Whether to include 'category_mask'
and 'instance_mask', which are required by the panoptic quality evaluator.
include_masks: Whether to include instance segmentation masks
(PNG encoded) in the result. default: False.
"""
......@@ -342,16 +493,29 @@ def _create_tf_record_from_coco_annotations(images_info_file,
img_to_obj_annotation = None
img_to_caption_annotation = None
id_to_name_map = None
img_to_panoptic_annotation = None
is_category_thing = None
if object_annotations_file:
img_to_obj_annotation, id_to_name_map = (
_load_object_annotations(object_annotations_file))
if caption_annotations_file:
img_to_caption_annotation = (
_load_caption_annotations(caption_annotations_file))
if panoptic_annotations_file:
img_to_panoptic_annotation, is_category_thing = (
_load_panoptic_annotations(panoptic_annotations_file))
coco_annotations_iter = generate_annotations(
images, image_dirs, img_to_obj_annotation, img_to_caption_annotation,
id_to_name_map=id_to_name_map, include_masks=include_masks)
images=images,
image_dirs=image_dirs,
panoptic_masks_dir=panoptic_masks_dir,
img_to_obj_annotation=img_to_obj_annotation,
img_to_caption_annotation=img_to_caption_annotation,
img_to_panoptic_annotation=img_to_panoptic_annotation,
is_category_thing=is_category_thing,
id_to_name_map=id_to_name_map,
include_panoptic_masks=include_panoptic_masks,
include_masks=include_masks)
num_skipped = tfrecord_lib.write_tf_record_dataset(
output_path, coco_annotations_iter, create_tf_example, num_shards)
......@@ -380,6 +544,9 @@ def main(_):
FLAGS.num_shards,
FLAGS.object_annotations_file,
FLAGS.caption_annotations_file,
FLAGS.panoptic_masks_dir,
FLAGS.panoptic_annotations_file,
FLAGS.include_panoptic_masks,
FLAGS.include_masks)
......
......@@ -15,7 +15,7 @@ done
cocosplit_url="dl.yf.io/fs-det/datasets/cocosplit"
wget --recursive --no-parent -q --show-progress --progress=bar:force:noscroll \
-P "${tmp_dir}" -A "trainvalno5k.json,5k.json,*10shot*.json,*30shot*.json" \
-P "${tmp_dir}" -A "trainvalno5k.json,5k.json,*1shot*.json,*3shot*.json,*5shot*.json,*10shot*.json,*30shot*.json" \
"http://${cocosplit_url}/"
mv "${tmp_dir}/${cocosplit_url}/"* "${tmp_dir}"
rm -rf "${tmp_dir}/${cocosplit_url}/"
......@@ -24,7 +24,7 @@ python process_coco_few_shot_json_files.py \
--logtostderr --workdir="${tmp_dir}"
for seed in {0..9}; do
for shots in 10 30; do
for shots in 1 3 5 10 30; do
python create_coco_tf_record.py \
--logtostderr \
--image_dir="${base_image_dir}/train2014" \
......
......@@ -53,7 +53,7 @@ CATEGORIES = ['airplane', 'apple', 'backpack', 'banana', 'baseball bat',
'traffic light', 'train', 'truck', 'tv', 'umbrella', 'vase',
'wine glass', 'zebra']
SEEDS = list(range(10))
SHOTS = [10, 30]
SHOTS = [1, 3, 5, 10, 30]
FILE_SUFFIXES = collections.defaultdict(list)
for _seed, _shots in itertools.product(SEEDS, SHOTS):
......
......@@ -100,8 +100,13 @@ def image_info_to_feature_dict(height, width, filename, image_id,
}
def encode_binary_mask_as_png(binary_mask):
pil_image = Image.fromarray(binary_mask)
def read_image(image_path):
"""Reads an image file into a numpy array."""
pil_image = Image.open(image_path)
return np.asarray(pil_image)
def encode_mask_as_png(mask):
"""Encodes a [height, width] uint8 mask as a PNG string."""
pil_image = Image.fromarray(mask)
output_io = io.BytesIO()
pil_image.save(output_io, format='PNG')
return output_io.getvalue()
......
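A round-trip sketch for the helpers added above, using encode_mask_as_png as defined in this file (assumes PIL and numpy are available):

import io

import numpy as np
from PIL import Image

mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 7
png_bytes = encode_mask_as_png(mask)
decoded = np.asarray(Image.open(io.BytesIO(png_bytes)))
assert np.array_equal(decoded, mask)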
......@@ -38,7 +38,6 @@ class TfExampleDecoder(decoder.Decoder):
self._regenerate_source_id = regenerate_source_id
self._keys_to_features = {
'image/encoded': tf.io.FixedLenFeature((), tf.string),
'image/source_id': tf.io.FixedLenFeature((), tf.string),
'image/height': tf.io.FixedLenFeature((), tf.int64),
'image/width': tf.io.FixedLenFeature((), tf.int64),
'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
......@@ -54,6 +53,10 @@ class TfExampleDecoder(decoder.Decoder):
self._keys_to_features.update({
'image/object/mask': tf.io.VarLenFeature(tf.string),
})
if not regenerate_source_id:
self._keys_to_features.update({
'image/source_id': tf.io.FixedLenFeature((), tf.string),
})
def _decode_image(self, parsed_tensors):
"""Decodes the image and set its static shape."""
......
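With this change, image/source_id is only required when it will not be regenerated. A hedged sketch, with constructor argument names assumed from this file:

# Examples serialized without an image/source_id feature now parse,
# provided the decoder is told to regenerate ids.
decoder = TfExampleDecoder(include_mask=True, regenerate_source_id=True)
decoded_tensors = decoder.decode(serialized_example)  # placeholder input tensor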
......@@ -131,7 +131,6 @@ def convert_predictions_to_coco_annotations(predictions):
"""
coco_predictions = []
num_batches = len(predictions['source_id'])
batch_size = predictions['source_id'][0].shape[0]
max_num_detections = predictions['detection_classes'][0].shape[1]
use_outer_box = 'detection_outer_boxes' in predictions
for i in range(num_batches):
......@@ -144,6 +143,7 @@ def convert_predictions_to_coco_annotations(predictions):
else:
mask_boxes = predictions['detection_boxes']
batch_size = predictions['source_id'][i].shape[0]
for j in range(batch_size):
if 'detection_masks' in predictions:
image_masks = mask_ops.paste_instance_masks(
......@@ -211,9 +211,9 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
gt_annotations = []
num_batches = len(groundtruths['source_id'])
batch_size = groundtruths['source_id'][0].shape[0]
for i in range(num_batches):
max_num_instances = groundtruths['classes'][i].shape[1]
batch_size = groundtruths['source_id'][i].shape[0]
for j in range(batch_size):
num_instances = groundtruths['num_detections'][i][j]
if num_instances > max_num_instances:
......
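The motivation for computing batch_size per batch is ragged final batches. A standalone numpy sketch of the indexing hazard the old code risked:

import numpy as np

source_ids = [np.arange(4), np.arange(4), np.arange(2)]  # last batch is partial
for i in range(len(source_ids)):
    batch_size = source_ids[i].shape[0]  # per batch, not source_ids[0].shape[0]
    for j in range(batch_size):
        _ = source_ids[i][j]  # never indexes past the short final batch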
......@@ -342,9 +342,10 @@ Berkin Akin, Suyog Gupta, and Andrew Howard
"""
MNMultiMAX_BLOCK_SPECS = {
'spec_name': 'MobileNetMultiMAX',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
'activation', 'expand_ratio',
'use_normalization', 'use_bias', 'is_output'],
'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
],
'block_specs': [
('convbn', 3, 2, 32, 'relu', None, True, False, False),
('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False, True),
......@@ -363,15 +364,18 @@ MNMultiMAX_BLOCK_SPECS = {
('invertedbottleneck', 5, 1, 160, 'relu', 4., None, False, True),
('convbn', 1, 1, 960, 'relu', None, True, False, False),
('gpooling', None, None, None, None, None, None, None, False),
('convbn', 1, 1, 1280, 'relu', None, False, True, False),
# Remove bias and add batch norm for the last layer to support QAT
# and achieve slightly better accuracy.
('convbn', 1, 1, 1280, 'relu', None, True, False, False),
]
}
MNMultiAVG_BLOCK_SPECS = {
'spec_name': 'MobileNetMultiAVG',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
'activation', 'expand_ratio',
'use_normalization', 'use_bias', 'is_output'],
'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
],
'block_specs': [
('convbn', 3, 2, 32, 'relu', None, True, False, False),
('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False, False),
......@@ -392,7 +396,9 @@ MNMultiAVG_BLOCK_SPECS = {
('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False, True),
('convbn', 1, 1, 960, 'relu', None, True, False, False),
('gpooling', None, None, None, None, None, None, None, False),
('convbn', 1, 1, 1280, 'relu', None, False, True, False),
# Remove bias and add batch norm for the last layer to support QAT
# and achieve slightly better accuracy.
('convbn', 1, 1, 1280, 'relu', None, True, False, False),
]
}
......
......@@ -158,10 +158,10 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
('MobileNetV3Small', 0.75): 1026552,
('MobileNetV3EdgeTPU', 1.0): 2849312,
('MobileNetV3EdgeTPU', 0.75): 1737288,
('MobileNetMultiAVG', 1.0): 3700576,
('MobileNetMultiAVG', 0.75): 2345864,
('MobileNetMultiMAX', 1.0): 3170720,
('MobileNetMultiMAX', 0.75): 2041976,
('MobileNetMultiAVG', 1.0): 3704416,
('MobileNetMultiAVG', 0.75): 2349704,
('MobileNetMultiMAX', 1.0): 3174560,
('MobileNetMultiMAX', 0.75): 2045816,
}
input_size = 224
......
......@@ -32,6 +32,12 @@ layers = tf.keras.layers
# Each element in the block configuration is in the following format:
# (block_fn, num_filters, block_repeats)
RESNET_SPECS = {
10: [
('residual', 64, 1),
('residual', 128, 1),
('residual', 256, 1),
('residual', 512, 1),
],
18: [
('residual', 64, 2),
('residual', 128, 2),
......@@ -114,6 +120,7 @@ class ResNet(tf.keras.Model):
replace_stem_max_pool: bool = False,
se_ratio: Optional[float] = None,
init_stochastic_depth_rate: float = 0.0,
scale_stem: bool = True,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
......@@ -121,6 +128,7 @@ class ResNet(tf.keras.Model):
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bn_trainable: bool = True,
**kwargs):
"""Initializes a ResNet model.
......@@ -138,6 +146,7 @@ class ResNet(tf.keras.Model):
with a stride-2 conv,
se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
scale_stem: A `bool` of whether to scale stem layers.
activation: A `str` name of the activation function.
use_sync_bn: If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
......@@ -147,6 +156,8 @@ class ResNet(tf.keras.Model):
Conv2D. Default to None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
Default to None.
bn_trainable: A `bool` that indicates whether batch norm layers should be
trainable. Default to True.
**kwargs: Additional keyword arguments to be passed.
"""
self._model_id = model_id
......@@ -157,6 +168,7 @@ class ResNet(tf.keras.Model):
self._replace_stem_max_pool = replace_stem_max_pool
self._se_ratio = se_ratio
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._scale_stem = scale_stem
self._use_sync_bn = use_sync_bn
self._activation = activation
self._norm_momentum = norm_momentum
......@@ -168,6 +180,7 @@ class ResNet(tf.keras.Model):
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._bn_trainable = bn_trainable
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
......@@ -177,9 +190,10 @@ class ResNet(tf.keras.Model):
# Build ResNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:])
stem_depth_multiplier = self._depth_multiplier if scale_stem else 1.0
if stem_type == 'v0':
x = layers.Conv2D(
filters=int(64 * self._depth_multiplier),
filters=int(64 * stem_depth_multiplier),
kernel_size=7,
strides=2,
use_bias=False,
......@@ -189,12 +203,15 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
elif stem_type == 'v1':
x = layers.Conv2D(
filters=int(32 * self._depth_multiplier),
filters=int(32 * stem_depth_multiplier),
kernel_size=3,
strides=2,
use_bias=False,
......@@ -204,11 +221,14 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
x = layers.Conv2D(
filters=int(32 * self._depth_multiplier),
filters=int(32 * stem_depth_multiplier),
kernel_size=3,
strides=1,
use_bias=False,
......@@ -218,11 +238,14 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
x = layers.Conv2D(
filters=int(64 * self._depth_multiplier),
filters=int(64 * stem_depth_multiplier),
kernel_size=3,
strides=1,
use_bias=False,
......@@ -232,7 +255,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else:
......@@ -250,7 +276,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else:
......@@ -318,7 +347,8 @@ class ResNet(tf.keras.Model):
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
norm_epsilon=self._norm_epsilon,
bn_trainable=self._bn_trainable)(
inputs)
for _ in range(1, block_repeats):
......@@ -335,7 +365,8 @@ class ResNet(tf.keras.Model):
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
norm_epsilon=self._norm_epsilon,
bn_trainable=self._bn_trainable)(
x)
return tf.keras.layers.Activation('linear', name=name)(x)
......@@ -350,12 +381,14 @@ class ResNet(tf.keras.Model):
'activation': self._activation,
'se_ratio': self._se_ratio,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'scale_stem': self._scale_stem,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'bn_trainable': self._bn_trainable
}
return config_dict
......@@ -390,8 +423,10 @@ def build_resnet(
replace_stem_max_pool=backbone_cfg.replace_stem_max_pool,
se_ratio=backbone_cfg.se_ratio,
init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
scale_stem=backbone_cfg.scale_stem,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
kernel_regularizer=l2_regularizer,
bn_trainable=backbone_cfg.bn_trainable)
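A hedged construction sketch for the new backbone arguments; the remaining parameters fall back to the defaults in the constructor above:

# ResNet-10 from the new spec, with frozen batch norm layers and an
# unscaled stem, e.g. for fine-tuning with small batch sizes.
backbone = ResNet(
    model_id=10,
    scale_stem=False,
    bn_trainable=False)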
......@@ -28,6 +28,7 @@ from official.vision.beta.modeling.backbones import resnet
class ResNetTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(128, 10, 1),
(128, 18, 1),
(128, 34, 1),
(128, 50, 4),
......@@ -38,6 +39,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
endpoint_filter_scale):
"""Test creation of ResNet family models."""
resnet_params = {
10: 4915904,
18: 11190464,
34: 21306048,
50: 23561152,
......@@ -126,6 +128,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
resnetd_shortcut=False,
replace_stem_max_pool=False,
init_stochastic_depth_rate=0.0,
scale_stem=True,
use_sync_bn=False,
activation='relu',
norm_momentum=0.99,
......@@ -133,7 +136,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
)
bn_trainable=True)
network = resnet.ResNet(**kwargs)
expected_config = dict(kwargs)
......
......@@ -93,23 +93,6 @@ class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
def test_mobilenet_network_creation(self, mobilenet_model_id,
filter_size_scale):
"""Test for creation of a MobileNet classifier."""
mobilenet_params = {
('MobileNetV1', 1.0): 4254889,
('MobileNetV1', 0.75): 2602745,
('MobileNetV2', 1.0): 3540265,
('MobileNetV2', 0.75): 2664345,
('MobileNetV3Large', 1.0): 5508713,
('MobileNetV3Large', 0.75): 4013897,
('MobileNetV3Small', 1.0): 2555993,
('MobileNetV3Small', 0.75): 2052577,
('MobileNetV3EdgeTPU', 1.0): 4131593,
('MobileNetV3EdgeTPU', 0.75): 3019569,
('MobileNetMultiAVG', 1.0): 4982857,
('MobileNetMultiAVG', 0.75): 3628145,
('MobileNetMultiMAX', 1.0): 4453001,
('MobileNetMultiMAX', 0.75): 3324257,
}
inputs = np.random.rand(2, 224, 224, 3)
tf.keras.backend.set_image_data_format('channels_last')
......@@ -123,8 +106,6 @@ class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
num_classes=num_classes,
dropout_rate=0.2,
)
self.assertEqual(model.count_params(),
mobilenet_params[(mobilenet_model_id, filter_size_scale)])
logits = model(inputs)
self.assertAllEqual([2, num_classes], logits.numpy().shape)
......
......@@ -42,6 +42,7 @@ class ASPP(tf.keras.layers.Layer):
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear',
use_depthwise_convolution: bool = False,
**kwargs):
"""Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.
......@@ -64,6 +65,8 @@ class ASPP(tf.keras.layers.Layer):
interpolation: A `str` of interpolation method. It should be one of
`bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
`gaussian`, or `mitchellcubic`.
use_depthwise_convolution: If True, depthwise separable convolutions will
be used in the atrous spatial pyramid pooling.
**kwargs: Additional keyword arguments to be passed.
"""
super(ASPP, self).__init__(**kwargs)
......@@ -80,6 +83,7 @@ class ASPP(tf.keras.layers.Layer):
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'interpolation': interpolation,
'use_depthwise_convolution': use_depthwise_convolution,
}
def build(self, input_shape):
......@@ -100,7 +104,9 @@ class ASPP(tf.keras.layers.Layer):
dropout=self._config_dict['dropout_rate'],
kernel_initializer=self._config_dict['kernel_initializer'],
kernel_regularizer=self._config_dict['kernel_regularizer'],
interpolation=self._config_dict['interpolation'])
interpolation=self._config_dict['interpolation'],
use_depthwise_convolution=self._config_dict['use_depthwise_convolution']
)
def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
"""Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.
......@@ -167,6 +173,7 @@ def build_aspp_decoder(
level=decoder_cfg.level,
dilation_rates=decoder_cfg.dilation_rates,
num_filters=decoder_cfg.num_filters,
use_depthwise_convolution=decoder_cfg.use_depthwise_convolution,
pool_kernel_size=decoder_cfg.pool_kernel_size,
dropout_rate=decoder_cfg.dropout_rate,
use_sync_bn=norm_activation_config.use_sync_bn,
......
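A construction sketch for the new ASPP option, with argument names taken from the constructor in this diff:

aspp_layer = ASPP(
    level=3,
    dilation_rates=[6, 12, 18],
    num_filters=256,
    use_depthwise_convolution=True)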
......@@ -70,6 +70,7 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
kernel_regularizer=None,
interpolation='bilinear',
dropout_rate=0.2,
use_depthwise_convolution=False,
)
network = aspp.ASPP(**kwargs)
......