Unverified Commit 07d7cb85 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added experiment config for `panoptic_deeplab_coco`

parent 9e2fbc6b
...@@ -13,15 +13,19 @@ ...@@ -13,15 +13,19 @@
# limitations under the License. # limitations under the License.
"""Panoptic Deeplab configuration definition.""" """Panoptic Deeplab configuration definition."""
import os
import dataclasses import dataclasses
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.vision.beta.configs import common from official.modeling import optimization
from official.vision.beta.configs import backbones from official.vision.configs import common
from official.vision.beta.configs import decoders from official.vision.configs import backbones
from official.vision.configs import decoders
_COCO_INPUT_PATH_BASE = 'coco/tfrecords' _COCO_INPUT_PATH_BASE = 'coco/tfrecords'
...@@ -66,6 +70,7 @@ class PanopticDeeplabHead(hyperparams.Config): ...@@ -66,6 +70,7 @@ class PanopticDeeplabHead(hyperparams.Config):
upsample_factor: int = 1 upsample_factor: int = 1
low_level: Union[List[int], Tuple[int]] = (3, 2) low_level: Union[List[int], Tuple[int]] = (3, 2)
low_level_num_filters: Union[List[int], Tuple[int]] = (64, 32) low_level_num_filters: Union[List[int], Tuple[int]] = (64, 32)
fusion_num_output_filters: int = 256
@dataclasses.dataclass @dataclasses.dataclass
class SemanticHead(PanopticDeeplabHead): class SemanticHead(PanopticDeeplabHead):
...@@ -144,5 +149,156 @@ class PanopticDeeplabTask(cfg.TaskConfig): ...@@ -144,5 +149,156 @@ class PanopticDeeplabTask(cfg.TaskConfig):
losses: Losses = Losses() losses: Losses = Losses()
init_checkpoint: Optional[str] = None init_checkpoint: Optional[str] = None
init_checkpoint_modules: Union[str, List[str]] = 'all' # all, backbone, and/or decoder init_checkpoint_modules: Union[str, List[str]] = 'all' # all, backbone, and/or decoder
annotation_file: Optional[str] = None
evaluation: Evaluation = Evaluation() evaluation: Evaluation = Evaluation()
@exp_factory.register_config_factory('panoptic_deeplab_coco')
def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
"""COCO panoptic segmentation with Panoptic Deeplab."""
train_steps = 200000
train_batch_size = 64
eval_batch_size = 1
steps_per_epoch = _COCO_TRAIN_EXAMPLES // train_batch_size
validation_steps = _COCO_VAL_EXAMPLES // eval_batch_size
num_panoptic_categories = 201
num_thing_categories = 91
ignore_label = 0
is_thing = [False]
for idx in range(1, num_panoptic_categories):
is_thing.append(True if idx <= num_thing_categories else False)
input_size = [640, 640, 3]
output_stride = 16
aspp_dilation_rates = [6, 12, 18]
multigrid = [1, 2, 4]
stem_type = 'v1'
level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(
mixed_precision_dtype='float32', enable_xla=True),
task=PanopticDeeplabTask(
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400', # pylint: disable=line-too-long
init_checkpoint_modules=['backbone'],
model=PanopticDeeplab(
num_classes=num_panoptic_categories,
input_size=input_size,
backbone=backbones.Backbone(
type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
model_id=50,
output_stride=output_stride,
multigrid=multigrid,
stem_type=stem_type)),
decoder=decoders.Decoder(
type='aspp',
aspp=decoders.ASPP(
level=level,
num_filters=256,
dilation_rates=aspp_dilation_rates)),
semantic_head=SemanticHead(
level=level,
num_convs=2,
num_filters=256,
kernel_size=5,
use_depthwise_convolution=False,
upsample_factor=1,
low_level=(3, 2),
low_level_num_filters=(64, 32),
fusion_num_output_filters=256,
prediction_kernel_size=1),
instance_head=InstanceHead(
level=level,
num_convs=2,
num_filters=32,
kernel_size=5,
use_depthwise_convolution=False,
upsample_factor=1,
low_level=(3, 2),
low_level_num_filters=(32, 16),
fusion_num_output_filters=128,
prediction_kernel_size=1),
shared_decoder=False,
generate_panoptic_masks=True,
post_processor=PanopticDeeplabPostProcessor(
output_size=input_size[:2],
center_score_threshold=0.1,
thing_class_ids=[i for i in range(num_thing_categories)],
label_divisor=256 * 256 * 256,
stuff_area_limit=4096,
ignore_label=ignore_label,
nms_kernel=41,
keep_k_centers=200,
rescale_predictions=True)),
losses=Losses(
label_smoothing=0.0,
ignore_label=ignore_label,
l2_weight_decay=0.0,
use_groundtruth_dimension=True,
top_k_percent_pixels=0.2,
segmentation_loss_weight=1.0,
center_heatmap_loss_weight=200,
center_offset_loss_weight=0.01),
train_data=DataConfig(
input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
parser=Parser(
aug_scale_min=0.5,
aug_scale_max=1.5,
aug_rand_hflip=True,
sigma=8.0)),
validation_data=DataConfig(
input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size,
parser=Parser(
resize_eval_groundtruth=False,
groundtruth_padded_size=[640, 640],
aug_scale_min=1.0,
aug_scale_max=1.0,
aug_rand_hflip=False,
sigma=8.0),
drop_remainder=False),
evaluation=Evaluation(
ignored_label=ignore_label,
max_instances_per_category=256,
offset=256 * 256 * 256,
is_thing=is_thing,
rescale_predictions=True,
report_per_class_pq=True,
report_per_class_iou=False,
report_train_mean_iou=False)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=validation_steps,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adam',
},
'learning_rate': {
'polynomial': {
'initial_learning_rate': 0.001,
'decay_steps': train_steps,
'end_learning_rate': 0.0,
'power': 0.9
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 4 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment