Unverified Commit b4cd3351 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge pull request #19 from srihari-humbarwadi/panoptic-deeplab-resnet50_fix

Use ResNet v1 as default backbone
parents 44f6d511 11adf3e2
......@@ -50,11 +50,13 @@ class Parser(hyperparams.Config):
small_instance_weight: float = 3.0
dtype = 'float32'
@dataclasses.dataclass
class DataDecoder(common.DataDecoder):
"""Data decoder config."""
simple_decoder: common.TfExampleDecoder = common.TfExampleDecoder()
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Input config for training."""
......@@ -62,6 +64,7 @@ class DataConfig(cfg.DataConfig):
parser: Parser = Parser()
file_type: str = 'tfrecord'
@dataclasses.dataclass
class PanopticDeeplabHead(hyperparams.Config):
"""Panoptic Deeplab head config."""
......@@ -75,16 +78,19 @@ class PanopticDeeplabHead(hyperparams.Config):
low_level_num_filters: Union[List[int], Tuple[int]] = (64, 32)
fusion_num_output_filters: int = 256
@dataclasses.dataclass
class SemanticHead(PanopticDeeplabHead):
"""Semantic head config."""
prediction_kernel_size: int = 1
@dataclasses.dataclass
class InstanceHead(PanopticDeeplabHead):
"""Instance head config."""
prediction_kernel_size: int = 1
@dataclasses.dataclass
class PanopticDeeplabPostProcessor(hyperparams.Config):
"""Panoptic Deeplab PostProcessing config."""
......@@ -99,6 +105,7 @@ class PanopticDeeplabPostProcessor(hyperparams.Config):
keep_k_centers: int = 200
rescale_predictions: bool = True
@dataclasses.dataclass
class PanopticDeeplab(hyperparams.Config):
"""Panoptic Deeplab model config."""
......@@ -116,6 +123,7 @@ class PanopticDeeplab(hyperparams.Config):
generate_panoptic_masks: bool = True
post_processor: PanopticDeeplabPostProcessor = PanopticDeeplabPostProcessor()
@dataclasses.dataclass
class Losses(hyperparams.Config):
label_smoothing: float = 0.0
......@@ -127,6 +135,7 @@ class Losses(hyperparams.Config):
center_heatmap_loss_weight: float = 200
center_offset_loss_weight: float = 0.01
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
""" Evaluation config """
......@@ -141,6 +150,7 @@ class Evaluation(hyperparams.Config):
report_per_class_iou: bool = False
report_train_mean_iou: bool = True # Turning this off can speed up training.
@dataclasses.dataclass
class PanopticDeeplabTask(cfg.TaskConfig):
model: PanopticDeeplab = PanopticDeeplab()
......@@ -175,10 +185,9 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
output_stride = 16
aspp_dilation_rates = [6, 12, 18]
multigrid = [1, 2, 4]
stem_type = 'v0'
stem_type = 'v1'
level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(
mixed_precision_dtype='bfloat16', enable_xla=True),
......@@ -191,9 +200,12 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
backbone=backbones.Backbone(
type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
model_id=50,
stem_type=stem_type,
output_stride=output_stride,
multigrid=multigrid,
stem_type=stem_type)),
se_ratio=0.25,
last_stage_repeats=1,
stochastic_depth_drop_rate=0.2)),
decoder=decoders.Decoder(
type='aspp',
aspp=decoders.ASPP(
......@@ -201,6 +213,7 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
num_filters=256,
pool_kernel_size=input_size[:2],
dilation_rates=aspp_dilation_rates,
use_depthwise_convolution=True,
dropout_rate=0.1)),
semantic_head=SemanticHead(
level=level,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment