Unverified Commit b4cd3351 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge pull request #19 from srihari-humbarwadi/panoptic-deeplab-resnet50_fix

Use ResNet v1 as default backbone
parents 44f6d511 11adf3e2
...@@ -50,11 +50,13 @@ class Parser(hyperparams.Config): ...@@ -50,11 +50,13 @@ class Parser(hyperparams.Config):
small_instance_weight: float = 3.0 small_instance_weight: float = 3.0
dtype = 'float32' dtype = 'float32'
@dataclasses.dataclass @dataclasses.dataclass
class DataDecoder(common.DataDecoder): class DataDecoder(common.DataDecoder):
"""Data decoder config.""" """Data decoder config."""
simple_decoder: common.TfExampleDecoder = common.TfExampleDecoder() simple_decoder: common.TfExampleDecoder = common.TfExampleDecoder()
@dataclasses.dataclass @dataclasses.dataclass
class DataConfig(cfg.DataConfig): class DataConfig(cfg.DataConfig):
"""Input config for training.""" """Input config for training."""
...@@ -62,6 +64,7 @@ class DataConfig(cfg.DataConfig): ...@@ -62,6 +64,7 @@ class DataConfig(cfg.DataConfig):
parser: Parser = Parser() parser: Parser = Parser()
file_type: str = 'tfrecord' file_type: str = 'tfrecord'
@dataclasses.dataclass @dataclasses.dataclass
class PanopticDeeplabHead(hyperparams.Config): class PanopticDeeplabHead(hyperparams.Config):
"""Panoptic Deeplab head config.""" """Panoptic Deeplab head config."""
...@@ -75,16 +78,19 @@ class PanopticDeeplabHead(hyperparams.Config): ...@@ -75,16 +78,19 @@ class PanopticDeeplabHead(hyperparams.Config):
low_level_num_filters: Union[List[int], Tuple[int]] = (64, 32) low_level_num_filters: Union[List[int], Tuple[int]] = (64, 32)
fusion_num_output_filters: int = 256 fusion_num_output_filters: int = 256
@dataclasses.dataclass @dataclasses.dataclass
class SemanticHead(PanopticDeeplabHead): class SemanticHead(PanopticDeeplabHead):
"""Semantic head config.""" """Semantic head config."""
prediction_kernel_size: int = 1 prediction_kernel_size: int = 1
@dataclasses.dataclass @dataclasses.dataclass
class InstanceHead(PanopticDeeplabHead): class InstanceHead(PanopticDeeplabHead):
"""Instance head config.""" """Instance head config."""
prediction_kernel_size: int = 1 prediction_kernel_size: int = 1
@dataclasses.dataclass @dataclasses.dataclass
class PanopticDeeplabPostProcessor(hyperparams.Config): class PanopticDeeplabPostProcessor(hyperparams.Config):
"""Panoptic Deeplab PostProcessing config.""" """Panoptic Deeplab PostProcessing config."""
...@@ -99,6 +105,7 @@ class PanopticDeeplabPostProcessor(hyperparams.Config): ...@@ -99,6 +105,7 @@ class PanopticDeeplabPostProcessor(hyperparams.Config):
keep_k_centers: int = 200 keep_k_centers: int = 200
rescale_predictions: bool = True rescale_predictions: bool = True
@dataclasses.dataclass @dataclasses.dataclass
class PanopticDeeplab(hyperparams.Config): class PanopticDeeplab(hyperparams.Config):
"""Panoptic Deeplab model config.""" """Panoptic Deeplab model config."""
...@@ -116,6 +123,7 @@ class PanopticDeeplab(hyperparams.Config): ...@@ -116,6 +123,7 @@ class PanopticDeeplab(hyperparams.Config):
generate_panoptic_masks: bool = True generate_panoptic_masks: bool = True
post_processor: PanopticDeeplabPostProcessor = PanopticDeeplabPostProcessor() post_processor: PanopticDeeplabPostProcessor = PanopticDeeplabPostProcessor()
@dataclasses.dataclass @dataclasses.dataclass
class Losses(hyperparams.Config): class Losses(hyperparams.Config):
label_smoothing: float = 0.0 label_smoothing: float = 0.0
...@@ -127,6 +135,7 @@ class Losses(hyperparams.Config): ...@@ -127,6 +135,7 @@ class Losses(hyperparams.Config):
center_heatmap_loss_weight: float = 200 center_heatmap_loss_weight: float = 200
center_offset_loss_weight: float = 0.01 center_offset_loss_weight: float = 0.01
@dataclasses.dataclass @dataclasses.dataclass
class Evaluation(hyperparams.Config): class Evaluation(hyperparams.Config):
""" Evaluation config """ """ Evaluation config """
...@@ -141,6 +150,7 @@ class Evaluation(hyperparams.Config): ...@@ -141,6 +150,7 @@ class Evaluation(hyperparams.Config):
report_per_class_iou: bool = False report_per_class_iou: bool = False
report_train_mean_iou: bool = True # Turning this off can speed up training. report_train_mean_iou: bool = True # Turning this off can speed up training.
@dataclasses.dataclass @dataclasses.dataclass
class PanopticDeeplabTask(cfg.TaskConfig): class PanopticDeeplabTask(cfg.TaskConfig):
model: PanopticDeeplab = PanopticDeeplab() model: PanopticDeeplab = PanopticDeeplab()
...@@ -175,10 +185,9 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig: ...@@ -175,10 +185,9 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
output_stride = 16 output_stride = 16
aspp_dilation_rates = [6, 12, 18] aspp_dilation_rates = [6, 12, 18]
multigrid = [1, 2, 4] multigrid = [1, 2, 4]
stem_type = 'v0' stem_type = 'v1'
level = int(np.math.log2(output_stride)) level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig( runtime=cfg.RuntimeConfig(
mixed_precision_dtype='bfloat16', enable_xla=True), mixed_precision_dtype='bfloat16', enable_xla=True),
...@@ -191,9 +200,12 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig: ...@@ -191,9 +200,12 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
backbone=backbones.Backbone( backbone=backbones.Backbone(
type='dilated_resnet', dilated_resnet=backbones.DilatedResNet( type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
model_id=50, model_id=50,
stem_type=stem_type,
output_stride=output_stride, output_stride=output_stride,
multigrid=multigrid, multigrid=multigrid,
stem_type=stem_type)), se_ratio=0.25,
last_stage_repeats=1,
stochastic_depth_drop_rate=0.2)),
decoder=decoders.Decoder( decoder=decoders.Decoder(
type='aspp', type='aspp',
aspp=decoders.ASPP( aspp=decoders.ASPP(
...@@ -201,6 +213,7 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig: ...@@ -201,6 +213,7 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
num_filters=256, num_filters=256,
pool_kernel_size=input_size[:2], pool_kernel_size=input_size[:2],
dilation_rates=aspp_dilation_rates, dilation_rates=aspp_dilation_rates,
use_depthwise_convolution=True,
dropout_rate=0.1)), dropout_rate=0.1)),
semantic_head=SemanticHead( semantic_head=SemanticHead(
level=level, level=level,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment