Unverified Commit ca552843 authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'panoptic-segmentation' into panoptic-segmentation

parents 7e2f7a35 6b90e134
@@ -30,8 +30,10 @@ class ResNet(hyperparams.Config):
   stem_type: str = 'v0'
   se_ratio: float = 0.0
   stochastic_depth_drop_rate: float = 0.0
+  scale_stem: bool = True
   resnetd_shortcut: bool = False
   replace_stem_max_pool: bool = False
+  bn_trainable: bool = True

 @dataclasses.dataclass
...
@@ -15,15 +15,44 @@
 # Lint as: python3
 """Common configurations."""

+import dataclasses
 from typing import Optional

 # Import libraries
-
-import dataclasses

 from official.core import config_definitions as cfg
 from official.modeling import hyperparams


+@dataclasses.dataclass
+class TfExampleDecoder(hyperparams.Config):
+  """A simple TF Example decoder config."""
+  regenerate_source_id: bool = False
+  mask_binarize_threshold: Optional[float] = None
+
+
+@dataclasses.dataclass
+class TfExampleDecoderLabelMap(hyperparams.Config):
+  """TF Example decoder with label map config."""
+  regenerate_source_id: bool = False
+  mask_binarize_threshold: Optional[float] = None
+  label_map: str = ''
+
+
+@dataclasses.dataclass
+class DataDecoder(hyperparams.OneOfConfig):
+  """Data decoder config.
+
+  Attributes:
+    type: 'str', type of the data decoder to be used, one of the fields below.
+    simple_decoder: simple TF Example decoder config.
+    label_map_decoder: TF Example decoder with label map config.
+  """
+  type: Optional[str] = 'simple_decoder'
+  simple_decoder: TfExampleDecoder = TfExampleDecoder()
+  label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
+
+
 @dataclasses.dataclass
 class RandAugment(hyperparams.Config):
   """Configuration for RandAugment."""
...
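Note: `DataDecoder` is a `hyperparams.OneOfConfig`, so the `type` field selects which member config is active. A minimal usage sketch of switching to the label-map decoder (the label map path is hypothetical):

    from official.vision.beta.configs import common

    # Select the label-map decoder variant and point it at a label map file.
    decoder_cfg = common.DataDecoder(type='label_map_decoder')
    decoder_cfg.label_map_decoder.label_map = '/path/to/label_map'  # hypothetical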
@@ -32,6 +32,7 @@ class Identity(hyperparams.Config):
 class FPN(hyperparams.Config):
   """FPN config."""
   num_filters: int = 256
+  fusion_type: str = 'sum'
   use_separable_conv: bool = False
@@ -50,6 +51,7 @@ class ASPP(hyperparams.Config):
   dilation_rates: List[int] = dataclasses.field(default_factory=list)
   dropout_rate: float = 0.0
   num_filters: int = 256
+  use_depthwise_convolution: bool = False
   pool_kernel_size: Optional[List[int]] = None  # Use global average pooling.
...
@@ -14,11 +14,10 @@
 # Lint as: python3
 """Image classification configuration definition."""
+import dataclasses
 import os
 from typing import List, Optional

-import dataclasses
-
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import hyperparams
@@ -44,6 +43,7 @@ class DataConfig(cfg.DataConfig):
   image_field_key: str = 'image/encoded'
   label_field_key: str = 'image/class/label'
   decode_jpeg_only: bool = True
+  decoder: Optional[common.DataDecoder] = common.DataDecoder()

   # Keep for backward compatibility.
   aug_policy: Optional[str] = None  # None, 'autoaug', or 'randaug'.
...
@@ -13,11 +13,11 @@
 # limitations under the License.

 # Lint as: python3
-"""Mask R-CNN configuration definition."""
+"""R-CNN(-RS) configuration definition."""

 import dataclasses
 import os
-from typing import List, Optional
+from typing import List, Optional, Union

 from official.core import config_definitions as cfg
 from official.core import exp_factory
@@ -29,26 +29,6 @@ from official.vision.beta.configs import backbones

 # pylint: disable=missing-class-docstring

-@dataclasses.dataclass
-class TfExampleDecoder(hyperparams.Config):
-  regenerate_source_id: bool = False
-  mask_binarize_threshold: Optional[float] = None
-
-
-@dataclasses.dataclass
-class TfExampleDecoderLabelMap(hyperparams.Config):
-  regenerate_source_id: bool = False
-  mask_binarize_threshold: Optional[float] = None
-  label_map: str = ''
-
-
-@dataclasses.dataclass
-class DataDecoder(hyperparams.OneOfConfig):
-  type: Optional[str] = 'simple_decoder'
-  simple_decoder: TfExampleDecoder = TfExampleDecoder()
-  label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
-
-
 @dataclasses.dataclass
 class Parser(hyperparams.Config):
   num_channels: int = 3
@@ -73,7 +53,7 @@ class DataConfig(cfg.DataConfig):
   global_batch_size: int = 0
   is_training: bool = False
   dtype: str = 'bfloat16'
-  decoder: DataDecoder = DataDecoder()
+  decoder: common.DataDecoder = common.DataDecoder()
   parser: Parser = Parser()
   shuffle_buffer_size: int = 10000
   file_type: str = 'tfrecord'
@@ -152,6 +132,7 @@ class DetectionGenerator(hyperparams.Config):
   nms_iou_threshold: float = 0.5
   max_num_detections: int = 100
   use_batched_nms: bool = False
+  use_cpu_nms: bool = False

 @dataclasses.dataclass
@@ -221,7 +202,8 @@ class MaskRCNNTask(cfg.TaskConfig):
       drop_remainder=False)
   losses: Losses = Losses()
   init_checkpoint: Optional[str] = None
-  init_checkpoint_modules: str = 'all'  # all or backbone
+  init_checkpoint_modules: Union[
+      str, List[str]] = 'all'  # all, backbone, and/or decoder
   annotation_file: Optional[str] = None
   per_category_metrics: bool = False
   # If set, we only use masks for the specified class IDs.
@@ -450,7 +432,7 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
 @exp_factory.register_config_factory('cascadercnn_spinenet_coco')
 def cascadercnn_spinenet_coco() -> cfg.ExperimentConfig:
-  """COCO object detection with Cascade R-CNN with SpineNet backbone."""
+  """COCO object detection with Cascade RCNN-RS with SpineNet backbone."""
   steps_per_epoch = 463
   coco_val_samples = 5000
   train_batch_size = 256
...
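Note: `init_checkpoint_modules` is now a `Union[str, List[str]]`, so backbone and decoder weights can be restored together. A minimal sketch (the checkpoint path is hypothetical):

    from official.vision.beta.configs import maskrcnn

    task = maskrcnn.MaskRCNNTask(
        init_checkpoint='gs://bucket/pretrained/ckpt',  # hypothetical path
        init_checkpoint_modules=['backbone', 'decoder'])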
@@ -15,9 +15,9 @@
 # Lint as: python3
 """RetinaNet configuration definition."""

-import os
-from typing import List, Optional
 import dataclasses
+import os
+from typing import List, Optional, Union

 from official.core import config_definitions as cfg
 from official.core import exp_factory
@@ -29,22 +29,22 @@ from official.vision.beta.configs import backbones

 # pylint: disable=missing-class-docstring

+# Keep for backward compatibility.
 @dataclasses.dataclass
-class TfExampleDecoder(hyperparams.Config):
-  regenerate_source_id: bool = False
+class TfExampleDecoder(common.TfExampleDecoder):
+  """A simple TF Example decoder config."""

+# Keep for backward compatibility.
 @dataclasses.dataclass
-class TfExampleDecoderLabelMap(hyperparams.Config):
-  regenerate_source_id: bool = False
-  label_map: str = ''
+class TfExampleDecoderLabelMap(common.TfExampleDecoderLabelMap):
+  """TF Example decoder with label map config."""

+# Keep for backward compatibility.
 @dataclasses.dataclass
-class DataDecoder(hyperparams.OneOfConfig):
-  type: Optional[str] = 'simple_decoder'
-  simple_decoder: TfExampleDecoder = TfExampleDecoder()
-  label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
+class DataDecoder(common.DataDecoder):
+  """Data decoder config."""

 @dataclasses.dataclass
@@ -55,6 +55,7 @@ class Parser(hyperparams.Config):
   aug_rand_hflip: bool = False
   aug_scale_min: float = 1.0
   aug_scale_max: float = 1.0
+  aug_policy: Optional[str] = None
   skip_crowd_during_training: bool = True
   max_num_instances: int = 100
@@ -66,7 +67,7 @@ class DataConfig(cfg.DataConfig):
   global_batch_size: int = 0
   is_training: bool = False
   dtype: str = 'bfloat16'
-  decoder: DataDecoder = DataDecoder()
+  decoder: common.DataDecoder = common.DataDecoder()
   parser: Parser = Parser()
   shuffle_buffer_size: int = 10000
   file_type: str = 'tfrecord'
@@ -112,6 +113,7 @@ class DetectionGenerator(hyperparams.Config):
   nms_iou_threshold: float = 0.5
   max_num_detections: int = 100
   use_batched_nms: bool = False
+  use_cpu_nms: bool = False

 @dataclasses.dataclass
@@ -144,7 +146,8 @@ class RetinaNetTask(cfg.TaskConfig):
   validation_data: DataConfig = DataConfig(is_training=False)
   losses: Losses = Losses()
   init_checkpoint: Optional[str] = None
-  init_checkpoint_modules: str = 'all'  # all or backbone
+  init_checkpoint_modules: Union[
+      str, List[str]] = 'all'  # all, backbone, and/or decoder
   annotation_file: Optional[str] = None
   per_category_metrics: bool = False
   export_config: ExportConfig = ExportConfig()
...
@@ -14,10 +14,10 @@
 # Lint as: python3
 """Semantic segmentation configuration definition."""
+import dataclasses
 import os
 from typing import List, Optional, Union

-import dataclasses
 import numpy as np

 from official.core import exp_factory
@@ -50,8 +50,10 @@ class DataConfig(cfg.DataConfig):
   aug_scale_min: float = 1.0
   aug_scale_max: float = 1.0
   aug_rand_hflip: bool = True
+  aug_policy: Optional[str] = None
   drop_remainder: bool = True
   file_type: str = 'tfrecord'
+  decoder: Optional[common.DataDecoder] = common.DataDecoder()

 @dataclasses.dataclass
@@ -60,6 +62,7 @@ class SegmentationHead(hyperparams.Config):
   level: int = 3
   num_convs: int = 2
   num_filters: int = 256
+  use_depthwise_convolution: bool = False
   prediction_kernel_size: int = 1
   upsample_factor: int = 1
   feature_fusion: Optional[str] = None  # None, deeplabv3plus, or pyramid_fusion
@@ -119,7 +122,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
 def semantic_segmentation() -> cfg.ExperimentConfig:
   """Semantic segmentation general."""
   return cfg.ExperimentConfig(
-      task=SemanticSegmentationModel(),
+      task=SemanticSegmentationTask(),
       trainer=cfg.TrainerConfig(),
       restrictions=[
           'task.train_data.is_training != None',
...
@@ -58,6 +58,14 @@ flags.DEFINE_string(
                     'annotations - boxes and instance masks.')
 flags.DEFINE_string('caption_annotations_file', '', 'File containing image '
                     'captions.')
+flags.DEFINE_string('panoptic_annotations_file', '', 'File containing panoptic '
+                    'annotations.')
+flags.DEFINE_string('panoptic_masks_dir', '',
+                    'Directory containing panoptic masks annotations.')
+flags.DEFINE_boolean(
+    'include_panoptic_masks', False, 'Whether to include category and '
+    'instance masks in the result. These are required to run the PQ '
+    'evaluator. default: False.')
 flags.DEFINE_string('output_file_prefix', '/tmp/train', 'Path to output file')
 flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')
@@ -66,6 +74,11 @@ FLAGS = flags.FLAGS
 logger = tf.get_logger()
 logger.setLevel(logging.INFO)

+_VOID_LABEL = 0
+_VOID_INSTANCE_ID = 0
+_THING_CLASS_ID = 1
+_STUFF_CLASSES_OFFSET = 90

 def coco_segmentation_to_mask_png(segmentation, height, width, is_crowd):
   """Encode a COCO mask segmentation as PNG string."""
@@ -74,12 +87,79 @@ def coco_segmentation_to_mask_png(segmentation, height, width, is_crowd):
   if not is_crowd:
     binary_mask = np.amax(binary_mask, axis=2)
-  return tfrecord_lib.encode_binary_mask_as_png(binary_mask)
+  return tfrecord_lib.encode_mask_as_png(binary_mask)
+
+
+def generate_coco_panoptics_masks(segments_info, mask_path,
+                                  include_panoptic_masks,
+                                  is_category_thing):
+  """Creates masks for the panoptic segmentation task.
+
+  Args:
+    segments_info: a list of dicts, where each dict has keys: [u'id',
+      u'category_id', u'area', u'bbox', u'iscrowd'], detailing information for
+      each segment in the panoptic mask.
+    mask_path: path to the panoptic mask.
+    include_panoptic_masks: bool, when set to True, category and instance
+      masks are included in the outputs. Set this to True when using the
+      Panoptic Quality evaluator.
+    is_category_thing: a dict with category ids as keys and 0/1 as values to
+      represent "stuff" and "things" classes respectively.
+
+  Returns:
+    A dict with keys: [u'semantic_segmentation_mask', u'category_mask',
+      u'instance_mask']. The dict contains 'category_mask' and 'instance_mask'
+      only if `include_panoptic_masks` is set to True.
+  """
+  rgb_mask = tfrecord_lib.read_image(mask_path)
+  r, g, b = np.split(rgb_mask, 3, axis=-1)
+
+  # Decode the RGB-encoded panoptic mask to get segment ids.
+  # See https://cocodataset.org/#format-data.
+  segments_encoded_mask = (r + g * 256 + b * (256**2)).squeeze()
+
+  semantic_segmentation_mask = np.ones_like(
+      segments_encoded_mask, dtype=np.uint8) * _VOID_LABEL
+  if include_panoptic_masks:
+    category_mask = np.ones_like(
+        segments_encoded_mask, dtype=np.uint8) * _VOID_LABEL
+    instance_mask = np.ones_like(
+        segments_encoded_mask, dtype=np.uint8) * _VOID_INSTANCE_ID
+
+  for idx, segment in enumerate(segments_info):
+    segment_id = segment['id']
+    category_id = segment['category_id']
+
+    if is_category_thing[category_id]:
+      encoded_category_id = _THING_CLASS_ID
+      instance_id = idx + 1
+    else:
+      encoded_category_id = category_id - _STUFF_CLASSES_OFFSET
+      instance_id = _VOID_INSTANCE_ID
+
+    segment_mask = (segments_encoded_mask == segment_id)
+    semantic_segmentation_mask[segment_mask] = encoded_category_id
+
+    if include_panoptic_masks:
+      category_mask[segment_mask] = category_id
+      instance_mask[segment_mask] = instance_id
+
+  outputs = {
+      'semantic_segmentation_mask': tfrecord_lib.encode_mask_as_png(
+          semantic_segmentation_mask)
+  }
+
+  if include_panoptic_masks:
+    outputs.update({
+        'category_mask': tfrecord_lib.encode_mask_as_png(category_mask),
+        'instance_mask': tfrecord_lib.encode_mask_as_png(instance_mask)
+    })
+  return outputs
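Note: the decoding above follows the COCO panoptic format, where each segment id is packed into the RGB channels as id = R + G*256 + B*256**2. A minimal round-trip sketch:

    import numpy as np

    segment_id = 3005
    # Pack into RGB channels the way the COCO panoptic PNGs do.
    rgb = np.array([segment_id % 256, (segment_id // 256) % 256,
                    segment_id // (256**2)], dtype=np.uint8)
    # Unpack with the same formula; cast first so the arithmetic
    # does not overflow uint8.
    recovered = int(rgb[0]) + int(rgb[1]) * 256 + int(rgb[2]) * (256**2)
    assert recovered == segment_id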
 def coco_annotations_to_lists(bbox_annotations, id_to_name_map,
                               image_height, image_width, include_masks):
-  """Convert COCO annotations to feature lists."""
+  """Converts COCO annotations to feature lists."""

   data = dict((k, list()) for k in
               ['xmin', 'xmax', 'ymin', 'ymax', 'is_crowd',
@@ -160,9 +240,13 @@ def encode_caption_annotations(caption_annotations):

 def create_tf_example(image,
                       image_dirs,
+                      panoptic_masks_dir=None,
                       bbox_annotations=None,
                       id_to_name_map=None,
                       caption_annotations=None,
+                      panoptic_annotation=None,
+                      is_category_thing=None,
+                      include_panoptic_masks=False,
                       include_masks=False):
   """Converts image and annotations to a tf.Example proto.

@@ -170,6 +254,7 @@ def create_tf_example(image,
     image: dict with keys: [u'license', u'file_name', u'coco_url', u'height',
       u'width', u'date_captured', u'flickr_url', u'id']
     image_dirs: list of directories containing the image files.
+    panoptic_masks_dir: `str` of the panoptic masks directory.
     bbox_annotations:
       list of dicts with keys: [u'segmentation', u'area', u'iscrowd',
       u'image_id', u'bbox', u'category_id', u'id'] Notice that bounding box
@@ -182,6 +267,11 @@ def create_tf_example(image,
     id_to_name_map: a dict mapping category IDs to string names.
     caption_annotations:
       list of dict with keys: [u'id', u'image_id', u'str'].
+    panoptic_annotation: dict with keys: [u'image_id', u'file_name',
+      u'segments_info'], where the value for segments_info is a list of dicts,
+      with each dict containing information for a single segment in the mask.
+    is_category_thing: a dict with category ids as keys and 0/1 as values to
+      represent "stuff" and "things" classes respectively.
+    include_panoptic_masks: `bool`, whether to include panoptic masks.
     include_masks: Whether to include instance segmentation masks
       (PNG encoded) in the result. default: False.
@@ -234,6 +324,26 @@ def create_tf_example(image,
     feature_dict.update(
         {'image/caption': tfrecord_lib.convert_to_feature(encoded_captions)})

+  if panoptic_annotation:
+    segments_info = panoptic_annotation['segments_info']
+    panoptic_mask_filename = os.path.join(
+        panoptic_masks_dir,
+        panoptic_annotation['file_name'])
+    encoded_panoptic_masks = generate_coco_panoptics_masks(
+        segments_info, panoptic_mask_filename, include_panoptic_masks,
+        is_category_thing)
+    feature_dict.update(
+        {'image/segmentation/class/encoded': tfrecord_lib.convert_to_feature(
+            encoded_panoptic_masks['semantic_segmentation_mask'])})
+
+    if include_panoptic_masks:
+      feature_dict.update({
+          'image/panoptic/category_mask': tfrecord_lib.convert_to_feature(
+              encoded_panoptic_masks['category_mask']),
+          'image/panoptic/instance_mask': tfrecord_lib.convert_to_feature(
+              encoded_panoptic_masks['instance_mask'])
+      })
+
   example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
   return example, num_annotations_skipped
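Note: a minimal sketch of reading the new panoptic features back out of the resulting records; the feature keys are the ones written above, everything else is illustrative:

    import tensorflow as tf

    _FEATURES = {
        'image/encoded': tf.io.FixedLenFeature((), tf.string),
        'image/segmentation/class/encoded': tf.io.FixedLenFeature((), tf.string),
        'image/panoptic/category_mask': tf.io.FixedLenFeature((), tf.string),
        'image/panoptic/instance_mask': tf.io.FixedLenFeature((), tf.string),
    }

    def parse_panoptic_record(serialized_example):
      parsed = tf.io.parse_single_example(serialized_example, _FEATURES)
      # The masks were serialized as PNG strings; decode back to HxWx1 uint8.
      category_mask = tf.io.decode_png(
          parsed['image/panoptic/category_mask'], channels=1)
      instance_mask = tf.io.decode_png(
          parsed['image/panoptic/instance_mask'], channels=1)
      return parsed['image/encoded'], category_mask, instance_mask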
@@ -287,6 +397,33 @@ def _load_caption_annotations(caption_annotations_file):
   return img_to_caption_annotation

+def _load_panoptic_annotations(panoptic_annotations_file):
+  """Loads panoptic annotations from file."""
+  with tf.io.gfile.GFile(panoptic_annotations_file, 'r') as fid:
+    panoptic_annotations = json.load(fid)
+
+  img_to_panoptic_annotation = dict()
+  logging.info('Building panoptic index.')
+  for annotation in panoptic_annotations['annotations']:
+    image_id = annotation['image_id']
+    img_to_panoptic_annotation[image_id] = annotation
+
+  is_category_thing = dict()
+  for category_info in panoptic_annotations['categories']:
+    is_category_thing[category_info['id']] = category_info['isthing'] == 1
+
+  missing_annotation_count = 0
+  images = panoptic_annotations['images']
+  for image in images:
+    image_id = image['id']
+    if image_id not in img_to_panoptic_annotation:
+      missing_annotation_count += 1
+  logging.info(
+      '%d images are missing panoptic annotations.', missing_annotation_count)
+
+  return img_to_panoptic_annotation, is_category_thing
+
+
 def _load_images_info(images_info_file):
   with tf.io.gfile.GFile(images_info_file, 'r') as fid:
     info_dict = json.load(fid)
@@ -294,11 +431,15 @@ def _load_images_info(images_info_file):

 def generate_annotations(images, image_dirs,
+                         panoptic_masks_dir=None,
                          img_to_obj_annotation=None,
-                         img_to_caption_annotation=None, id_to_name_map=None,
+                         img_to_caption_annotation=None,
+                         img_to_panoptic_annotation=None,
+                         is_category_thing=None,
+                         id_to_name_map=None,
+                         include_panoptic_masks=False,
                          include_masks=False):
   """Generator for COCO annotations."""
   for image in images:
     object_annotation = (img_to_obj_annotation.get(image['id'], None) if
                          img_to_obj_annotation else None)
@@ -306,8 +447,11 @@ def generate_annotations(images, image_dirs,
     caption_annotaion = (img_to_caption_annotation.get(image['id'], None) if
                          img_to_caption_annotation else None)

-    yield (image, image_dirs, object_annotation, id_to_name_map,
-           caption_annotaion, include_masks)
+    panoptic_annotation = (img_to_panoptic_annotation.get(image['id'], None) if
+                           img_to_panoptic_annotation else None)
+
+    yield (image, image_dirs, panoptic_masks_dir, object_annotation,
+           id_to_name_map, caption_annotaion, panoptic_annotation,
+           is_category_thing, include_panoptic_masks, include_masks)
 def _create_tf_record_from_coco_annotations(images_info_file,
@@ -316,6 +460,9 @@ def _create_tf_record_from_coco_annotations(images_info_file,
                                             num_shards,
                                             object_annotations_file=None,
                                             caption_annotations_file=None,
+                                            panoptic_masks_dir=None,
+                                            panoptic_annotations_file=None,
+                                            include_panoptic_masks=False,
                                             include_masks=False):
   """Loads COCO annotation json files and converts to tf.Record format.

@@ -331,6 +478,10 @@ def _create_tf_record_from_coco_annotations(images_info_file,
     num_shards: Number of output files to create.
     object_annotations_file: JSON file containing bounding box annotations.
     caption_annotations_file: JSON file containing caption annotations.
+    panoptic_masks_dir: Directory containing panoptic masks.
+    panoptic_annotations_file: JSON file containing panoptic annotations.
+    include_panoptic_masks: Whether to include 'category_mask' and
+      'instance_mask', which are required by the panoptic quality evaluator.
     include_masks: Whether to include instance segmentation masks
       (PNG encoded) in the result. default: False.
   """
@@ -342,16 +493,29 @@ def _create_tf_record_from_coco_annotations(images_info_file,
   img_to_obj_annotation = None
   img_to_caption_annotation = None
   id_to_name_map = None
+  img_to_panoptic_annotation = None
+  is_category_thing = None
   if object_annotations_file:
     img_to_obj_annotation, id_to_name_map = (
         _load_object_annotations(object_annotations_file))
   if caption_annotations_file:
     img_to_caption_annotation = (
         _load_caption_annotations(caption_annotations_file))
+  if panoptic_annotations_file:
+    img_to_panoptic_annotation, is_category_thing = (
+        _load_panoptic_annotations(panoptic_annotations_file))

-  coco_annotations_iter = generate_annotations(
-      images, image_dirs, img_to_obj_annotation, img_to_caption_annotation,
-      id_to_name_map=id_to_name_map, include_masks=include_masks)
+  coco_annotations_iter = generate_annotations(
+      images=images,
+      image_dirs=image_dirs,
+      panoptic_masks_dir=panoptic_masks_dir,
+      img_to_obj_annotation=img_to_obj_annotation,
+      img_to_caption_annotation=img_to_caption_annotation,
+      img_to_panoptic_annotation=img_to_panoptic_annotation,
+      is_category_thing=is_category_thing,
+      id_to_name_map=id_to_name_map,
+      include_panoptic_masks=include_panoptic_masks,
+      include_masks=include_masks)

   num_skipped = tfrecord_lib.write_tf_record_dataset(
       output_path, coco_annotations_iter, create_tf_example, num_shards)
@@ -380,6 +544,9 @@ def main(_):
       FLAGS.num_shards,
       FLAGS.object_annotations_file,
       FLAGS.caption_annotations_file,
+      FLAGS.panoptic_masks_dir,
+      FLAGS.panoptic_annotations_file,
+      FLAGS.include_panoptic_masks,
       FLAGS.include_masks)
...
@@ -15,7 +15,7 @@ done
 cocosplit_url="dl.yf.io/fs-det/datasets/cocosplit"
 wget --recursive --no-parent -q --show-progress --progress=bar:force:noscroll \
-  -P "${tmp_dir}" -A "trainvalno5k.json,5k.json,*10shot*.json,*30shot*.json" \
+  -P "${tmp_dir}" -A "trainvalno5k.json,5k.json,*1shot*.json,*3shot*.json,*5shot*.json,*10shot*.json,*30shot*.json" \
   "http://${cocosplit_url}/"
 mv "${tmp_dir}/${cocosplit_url}/"* "${tmp_dir}"
 rm -rf "${tmp_dir}/${cocosplit_url}/"
@@ -24,7 +24,7 @@ python process_coco_few_shot_json_files.py \
     --logtostderr --workdir="${tmp_dir}"

 for seed in {0..9}; do
-  for shots in 10 30; do
+  for shots in 1 3 5 10 30; do
     python create_coco_tf_record.py \
         --logtostderr \
         --image_dir="${base_image_dir}/train2014" \
...
@@ -53,7 +53,7 @@ CATEGORIES = ['airplane', 'apple', 'backpack', 'banana', 'baseball bat',
               'traffic light', 'train', 'truck', 'tv', 'umbrella', 'vase',
               'wine glass', 'zebra']
 SEEDS = list(range(10))
-SHOTS = [10, 30]
+SHOTS = [1, 3, 5, 10, 30]
 FILE_SUFFIXES = collections.defaultdict(list)
 for _seed, _shots in itertools.product(SEEDS, SHOTS):
...
@@ -100,8 +100,13 @@ def image_info_to_feature_dict(height, width, filename, image_id,
   }

-def encode_binary_mask_as_png(binary_mask):
-  pil_image = Image.fromarray(binary_mask)
+def read_image(image_path):
+  pil_image = Image.open(image_path)
+  return np.asarray(pil_image)
+
+
+def encode_mask_as_png(mask):
+  pil_image = Image.fromarray(mask)
   output_io = io.BytesIO()
   pil_image.save(output_io, format='PNG')
   return output_io.getvalue()
...
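Note: a minimal round-trip sketch for the two helpers above; PNG encoding is lossless, so the decoded array matches the input exactly:

    import io

    import numpy as np
    from PIL import Image

    mask = np.zeros((4, 4), dtype=np.uint8)
    mask[1:3, 1:3] = 7  # a small labeled region

    png_bytes = encode_mask_as_png(mask)
    decoded = np.asarray(Image.open(io.BytesIO(png_bytes)))
    assert np.array_equal(decoded, mask)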
@@ -38,7 +38,6 @@ class TfExampleDecoder(decoder.Decoder):
     self._regenerate_source_id = regenerate_source_id
     self._keys_to_features = {
         'image/encoded': tf.io.FixedLenFeature((), tf.string),
-        'image/source_id': tf.io.FixedLenFeature((), tf.string),
         'image/height': tf.io.FixedLenFeature((), tf.int64),
         'image/width': tf.io.FixedLenFeature((), tf.int64),
         'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
@@ -54,6 +53,10 @@ class TfExampleDecoder(decoder.Decoder):
       self._keys_to_features.update({
           'image/object/mask': tf.io.VarLenFeature(tf.string),
       })
+    if not regenerate_source_id:
+      self._keys_to_features.update({
+          'image/source_id': tf.io.FixedLenFeature((), tf.string),
+      })

   def _decode_image(self, parsed_tensors):
     """Decodes the image and sets its static shape."""
...
@@ -131,7 +131,6 @@ def convert_predictions_to_coco_annotations(predictions):
   """
   coco_predictions = []
   num_batches = len(predictions['source_id'])
-  batch_size = predictions['source_id'][0].shape[0]
   max_num_detections = predictions['detection_classes'][0].shape[1]
   use_outer_box = 'detection_outer_boxes' in predictions
   for i in range(num_batches):
@@ -144,6 +143,7 @@ def convert_predictions_to_coco_annotations(predictions):
     else:
       mask_boxes = predictions['detection_boxes']

+    batch_size = predictions['source_id'][i].shape[0]
     for j in range(batch_size):
       if 'detection_masks' in predictions:
         image_masks = mask_ops.paste_instance_masks(
@@ -211,9 +211,9 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
   gt_annotations = []
   num_batches = len(groundtruths['source_id'])
-  batch_size = groundtruths['source_id'][0].shape[0]
   for i in range(num_batches):
     max_num_instances = groundtruths['classes'][i].shape[1]
+    batch_size = groundtruths['source_id'][i].shape[0]
     for j in range(batch_size):
       num_instances = groundtruths['num_detections'][i][j]
       if num_instances > max_num_instances:
...
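Note: moving the `batch_size` lookup inside the loop matters because the last batch can be smaller when the dataset is not padded (e.g. `drop_remainder=False`). A minimal sketch of the shape pattern the fix handles:

    import numpy as np

    # 10 images at batch size 4 arrive as batches of 4, 4 and 2.
    source_id_batches = [np.zeros([4]), np.zeros([4]), np.zeros([2])]
    for i in range(len(source_id_batches)):
      # Per-batch size: 4, 4, 2. Reading source_id_batches[0].shape[0] once
      # would iterate past the end of the final, smaller batch.
      batch_size = source_id_batches[i].shape[0]
      for j in range(batch_size):
        pass  # process prediction j of batch i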
@@ -342,9 +342,10 @@ Berkin Akin, Suyog Gupta, and Andrew Howard
 """
 MNMultiMAX_BLOCK_SPECS = {
     'spec_name': 'MobileNetMultiMAX',
-    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
-                          'activation', 'expand_ratio',
-                          'use_normalization', 'use_bias', 'is_output'],
+    'block_spec_schema': [
+        'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
+        'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
+    ],
     'block_specs': [
         ('convbn', 3, 2, 32, 'relu', None, True, False, False),
         ('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False, True),
@@ -363,15 +364,18 @@ MNMultiMAX_BLOCK_SPECS = {
         ('invertedbottleneck', 5, 1, 160, 'relu', 4., None, False, True),
         ('convbn', 1, 1, 960, 'relu', None, True, False, False),
         ('gpooling', None, None, None, None, None, None, None, False),
-        ('convbn', 1, 1, 1280, 'relu', None, False, True, False),
+        # Remove bias and add batch norm for the last layer to support QAT
+        # and achieve slightly better accuracy.
+        ('convbn', 1, 1, 1280, 'relu', None, True, False, False),
     ]
 }

 MNMultiAVG_BLOCK_SPECS = {
     'spec_name': 'MobileNetMultiAVG',
-    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
-                          'activation', 'expand_ratio',
-                          'use_normalization', 'use_bias', 'is_output'],
+    'block_spec_schema': [
+        'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
+        'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
+    ],
     'block_specs': [
         ('convbn', 3, 2, 32, 'relu', None, True, False, False),
         ('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False, False),
@@ -392,7 +396,9 @@ MNMultiAVG_BLOCK_SPECS = {
         ('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False, True),
         ('convbn', 1, 1, 960, 'relu', None, True, False, False),
         ('gpooling', None, None, None, None, None, None, None, False),
-        ('convbn', 1, 1, 1280, 'relu', None, False, True, False),
+        # Remove bias and add batch norm for the last layer to support QAT
+        # and achieve slightly better accuracy.
+        ('convbn', 1, 1, 1280, 'relu', None, True, False, False),
     ]
 }
...
@@ -158,10 +158,10 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
         ('MobileNetV3Small', 0.75): 1026552,
         ('MobileNetV3EdgeTPU', 1.0): 2849312,
         ('MobileNetV3EdgeTPU', 0.75): 1737288,
-        ('MobileNetMultiAVG', 1.0): 3700576,
-        ('MobileNetMultiAVG', 0.75): 2345864,
-        ('MobileNetMultiMAX', 1.0): 3170720,
-        ('MobileNetMultiMAX', 0.75): 2041976,
+        ('MobileNetMultiAVG', 1.0): 3704416,
+        ('MobileNetMultiAVG', 0.75): 2349704,
+        ('MobileNetMultiMAX', 1.0): 3174560,
+        ('MobileNetMultiMAX', 0.75): 2045816,
     }

     input_size = 224
...
@@ -32,6 +32,12 @@ layers = tf.keras.layers
 # Each element in the block configuration is in the following format:
 # (block_fn, num_filters, block_repeats)
 RESNET_SPECS = {
+    10: [
+        ('residual', 64, 1),
+        ('residual', 128, 1),
+        ('residual', 256, 1),
+        ('residual', 512, 1),
+    ],
     18: [
         ('residual', 64, 2),
         ('residual', 128, 2),
@@ -114,6 +120,7 @@ class ResNet(tf.keras.Model):
                replace_stem_max_pool: bool = False,
                se_ratio: Optional[float] = None,
                init_stochastic_depth_rate: float = 0.0,
+               scale_stem: bool = True,
                activation: str = 'relu',
                use_sync_bn: bool = False,
                norm_momentum: float = 0.99,
@@ -121,6 +128,7 @@ class ResNet(tf.keras.Model):
                kernel_initializer: str = 'VarianceScaling',
                kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
                bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               bn_trainable: bool = True,
                **kwargs):
     """Initializes a ResNet model.

@@ -138,6 +146,7 @@ class ResNet(tf.keras.Model):
         with a stride-2 conv,
       se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
       init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
+      scale_stem: A `bool` of whether to scale stem layers.
       activation: A `str` name of the activation function.
       use_sync_bn: If True, use synchronized batch normalization.
       norm_momentum: A `float` of normalization momentum for the moving average.
@@ -147,6 +156,8 @@ class ResNet(tf.keras.Model):
         Conv2D. Default to None.
       bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
         Default to None.
+      bn_trainable: A `bool` that indicates whether batch norm layers should be
+        trainable. Default to True.
       **kwargs: Additional keyword arguments to be passed.
     """
     self._model_id = model_id
@@ -157,6 +168,7 @@ class ResNet(tf.keras.Model):
     self._replace_stem_max_pool = replace_stem_max_pool
     self._se_ratio = se_ratio
     self._init_stochastic_depth_rate = init_stochastic_depth_rate
+    self._scale_stem = scale_stem
     self._use_sync_bn = use_sync_bn
     self._activation = activation
     self._norm_momentum = norm_momentum
@@ -168,6 +180,7 @@ class ResNet(tf.keras.Model):
     self._kernel_initializer = kernel_initializer
     self._kernel_regularizer = kernel_regularizer
     self._bias_regularizer = bias_regularizer
+    self._bn_trainable = bn_trainable

     if tf.keras.backend.image_data_format() == 'channels_last':
       bn_axis = -1
@@ -177,9 +190,10 @@ class ResNet(tf.keras.Model):

     # Build ResNet.
     inputs = tf.keras.Input(shape=input_specs.shape[1:])
+    stem_depth_multiplier = self._depth_multiplier if scale_stem else 1.0

     if stem_type == 'v0':
       x = layers.Conv2D(
-          filters=int(64 * self._depth_multiplier),
+          filters=int(64 * stem_depth_multiplier),
           kernel_size=7,
           strides=2,
           use_bias=False,
@@ -189,12 +203,15 @@ class ResNet(tf.keras.Model):
           bias_regularizer=self._bias_regularizer)(
               inputs)
       x = self._norm(
-          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          trainable=bn_trainable)(
               x)
       x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
     elif stem_type == 'v1':
       x = layers.Conv2D(
-          filters=int(32 * self._depth_multiplier),
+          filters=int(32 * stem_depth_multiplier),
           kernel_size=3,
           strides=2,
           use_bias=False,
@@ -204,11 +221,14 @@ class ResNet(tf.keras.Model):
           bias_regularizer=self._bias_regularizer)(
               inputs)
       x = self._norm(
-          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          trainable=bn_trainable)(
              x)
       x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
       x = layers.Conv2D(
-          filters=int(32 * self._depth_multiplier),
+          filters=int(32 * stem_depth_multiplier),
           kernel_size=3,
           strides=1,
           use_bias=False,
@@ -218,11 +238,14 @@ class ResNet(tf.keras.Model):
           bias_regularizer=self._bias_regularizer)(
               x)
       x = self._norm(
-          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          trainable=bn_trainable)(
              x)
       x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
       x = layers.Conv2D(
-          filters=int(64 * self._depth_multiplier),
+          filters=int(64 * stem_depth_multiplier),
           kernel_size=3,
           strides=1,
           use_bias=False,
@@ -232,7 +255,10 @@ class ResNet(tf.keras.Model):
           bias_regularizer=self._bias_regularizer)(
               x)
       x = self._norm(
-          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          trainable=bn_trainable)(
              x)
       x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
     else:
@@ -250,7 +276,10 @@ class ResNet(tf.keras.Model):
           bias_regularizer=self._bias_regularizer)(
               x)
       x = self._norm(
-          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          trainable=bn_trainable)(
              x)
       x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
     else:
@@ -318,7 +347,8 @@ class ResNet(tf.keras.Model):
           activation=self._activation,
           use_sync_bn=self._use_sync_bn,
           norm_momentum=self._norm_momentum,
-          norm_epsilon=self._norm_epsilon)(
+          norm_epsilon=self._norm_epsilon,
+          bn_trainable=self._bn_trainable)(
               inputs)

     for _ in range(1, block_repeats):
@@ -335,7 +365,8 @@ class ResNet(tf.keras.Model):
           activation=self._activation,
           use_sync_bn=self._use_sync_bn,
           norm_momentum=self._norm_momentum,
-          norm_epsilon=self._norm_epsilon)(
+          norm_epsilon=self._norm_epsilon,
+          bn_trainable=self._bn_trainable)(
              x)

     return tf.keras.layers.Activation('linear', name=name)(x)
@@ -350,12 +381,14 @@ class ResNet(tf.keras.Model):
         'activation': self._activation,
         'se_ratio': self._se_ratio,
         'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
+        'scale_stem': self._scale_stem,
         'use_sync_bn': self._use_sync_bn,
         'norm_momentum': self._norm_momentum,
         'norm_epsilon': self._norm_epsilon,
         'kernel_initializer': self._kernel_initializer,
         'kernel_regularizer': self._kernel_regularizer,
         'bias_regularizer': self._bias_regularizer,
+        'bn_trainable': self._bn_trainable
     }
     return config_dict
@@ -390,8 +423,10 @@ def build_resnet(
       replace_stem_max_pool=backbone_cfg.replace_stem_max_pool,
       se_ratio=backbone_cfg.se_ratio,
       init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
+      scale_stem=backbone_cfg.scale_stem,
       activation=norm_activation_config.activation,
       use_sync_bn=norm_activation_config.use_sync_bn,
       norm_momentum=norm_activation_config.norm_momentum,
       norm_epsilon=norm_activation_config.norm_epsilon,
-      kernel_regularizer=l2_regularizer)
+      kernel_regularizer=l2_regularizer,
+      bn_trainable=backbone_cfg.bn_trainable)
@@ -28,6 +28,7 @@ from official.vision.beta.modeling.backbones import resnet
 class ResNetTest(parameterized.TestCase, tf.test.TestCase):

   @parameterized.parameters(
+      (128, 10, 1),
       (128, 18, 1),
       (128, 34, 1),
       (128, 50, 4),
@@ -38,6 +39,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
                            endpoint_filter_scale):
     """Test creation of ResNet family models."""
     resnet_params = {
+        10: 4915904,
         18: 11190464,
         34: 21306048,
         50: 23561152,
@@ -126,6 +128,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
         resnetd_shortcut=False,
         replace_stem_max_pool=False,
         init_stochastic_depth_rate=0.0,
+        scale_stem=True,
         use_sync_bn=False,
         activation='relu',
         norm_momentum=0.99,
@@ -133,7 +136,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
         kernel_initializer='VarianceScaling',
         kernel_regularizer=None,
         bias_regularizer=None,
-    )
+        bn_trainable=True)

     network = resnet.ResNet(**kwargs)
     expected_config = dict(kwargs)
...
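Note: a minimal sketch combining the new ResNet knobs, building the added ResNet-10 spec with frozen batch-norm layers, e.g. for fine-tuning (input shape illustrative, default input specs assumed):

    import tensorflow as tf

    from official.vision.beta.modeling.backbones import resnet

    backbone = resnet.ResNet(model_id=10, bn_trainable=False)
    endpoints = backbone(tf.ones([1, 64, 64, 3]))
    # `endpoints` maps feature levels to tensors, e.g. levels '2' through '5'.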
@@ -93,23 +93,6 @@ class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
   def test_mobilenet_network_creation(self, mobilenet_model_id,
                                       filter_size_scale):
     """Test for creation of a MobileNet classifier."""
-    mobilenet_params = {
-        ('MobileNetV1', 1.0): 4254889,
-        ('MobileNetV1', 0.75): 2602745,
-        ('MobileNetV2', 1.0): 3540265,
-        ('MobileNetV2', 0.75): 2664345,
-        ('MobileNetV3Large', 1.0): 5508713,
-        ('MobileNetV3Large', 0.75): 4013897,
-        ('MobileNetV3Small', 1.0): 2555993,
-        ('MobileNetV3Small', 0.75): 2052577,
-        ('MobileNetV3EdgeTPU', 1.0): 4131593,
-        ('MobileNetV3EdgeTPU', 0.75): 3019569,
-        ('MobileNetMultiAVG', 1.0): 4982857,
-        ('MobileNetMultiAVG', 0.75): 3628145,
-        ('MobileNetMultiMAX', 1.0): 4453001,
-        ('MobileNetMultiMAX', 0.75): 3324257,
-    }

     inputs = np.random.rand(2, 224, 224, 3)
     tf.keras.backend.set_image_data_format('channels_last')
@@ -123,8 +106,6 @@ class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
         num_classes=num_classes,
         dropout_rate=0.2,
     )
-    self.assertEqual(model.count_params(),
-                     mobilenet_params[(mobilenet_model_id, filter_size_scale)])

     logits = model(inputs)
     self.assertAllEqual([2, num_classes], logits.numpy().shape)
...
@@ -42,6 +42,7 @@ class ASPP(tf.keras.layers.Layer):
                kernel_initializer: str = 'VarianceScaling',
                kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
                interpolation: str = 'bilinear',
+               use_depthwise_convolution: bool = False,
                **kwargs):
     """Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.
@@ -64,6 +65,8 @@ class ASPP(tf.keras.layers.Layer):
       interpolation: A `str` of interpolation method. It should be one of
         `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
         `gaussian`, or `mitchellcubic`.
+      use_depthwise_convolution: If True, depthwise separable convolutions will
+        be added to the atrous spatial pyramid pooling.
       **kwargs: Additional keyword arguments to be passed.
     """
     super(ASPP, self).__init__(**kwargs)
@@ -80,6 +83,7 @@ class ASPP(tf.keras.layers.Layer):
         'kernel_initializer': kernel_initializer,
         'kernel_regularizer': kernel_regularizer,
         'interpolation': interpolation,
+        'use_depthwise_convolution': use_depthwise_convolution,
     }

   def build(self, input_shape):
@@ -100,7 +104,9 @@ class ASPP(tf.keras.layers.Layer):
         dropout=self._config_dict['dropout_rate'],
         kernel_initializer=self._config_dict['kernel_initializer'],
         kernel_regularizer=self._config_dict['kernel_regularizer'],
-        interpolation=self._config_dict['interpolation'])
+        interpolation=self._config_dict['interpolation'],
+        use_depthwise_convolution=self._config_dict['use_depthwise_convolution'])

   def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
     """Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.
@@ -167,6 +173,7 @@ def build_aspp_decoder(
       level=decoder_cfg.level,
       dilation_rates=decoder_cfg.dilation_rates,
       num_filters=decoder_cfg.num_filters,
+      use_depthwise_convolution=decoder_cfg.use_depthwise_convolution,
       pool_kernel_size=decoder_cfg.pool_kernel_size,
       dropout_rate=decoder_cfg.dropout_rate,
       use_sync_bn=norm_activation_config.use_sync_bn,
...
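Note: a minimal sketch of enabling the new option directly on the layer, mirroring the arguments `build_aspp_decoder` passes above (module path assumed, dilation rates illustrative):

    from official.vision.beta.modeling.decoders import aspp  # assumed path

    aspp_layer = aspp.ASPP(
        level=3,
        dilation_rates=[6, 12, 18],
        num_filters=256,
        use_depthwise_convolution=True)  # depthwise separable convs per branch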
@@ -70,6 +70,7 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
         kernel_regularizer=None,
         interpolation='bilinear',
         dropout_rate=0.2,
+        use_depthwise_convolution=False,
     )
     network = aspp.ASPP(**kwargs)
...