"src/include/Array.hpp" did not exist on "498e71b09822406b1b050c5eb03edebfe04038a6"
Unverified Commit 83b87f05 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added configs for dataloader, losses and task

parent 5e478a4c
......@@ -15,18 +15,46 @@
"""Panoptic Deeplab configuration definition."""
import dataclasses
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union
from official.core import config_definitions as cfg
from official.modeling import hyperparams
from official.vision.beta.configs import common
from official.vision.beta.configs import backbones
from official.vision.beta.configs import decoders
_COCO_INPUT_PATH_BASE = 'coco/tfrecords'
_COCO_TRAIN_EXAMPLES = 118287
_COCO_VAL_EXAMPLES = 5000
@dataclasses.dataclass
class Parser(hyperparams.Config):
  """Panoptic Deeplab parser (data preprocessing) config."""
  ignore_label: int = 0
  # If resize_eval_groundtruth is set to False, original image sizes are used
  # for eval. In that case, groundtruth_padded_size has to be specified too to
  # allow for batching the variable input sizes of images.
  resize_eval_groundtruth: bool = True
  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
  # Random scale-jitter range applied during training augmentation.
  aug_scale_min: float = 1.0
  aug_scale_max: float = 1.0
  aug_rand_hflip: bool = True
  # Gaussian sigma — presumably used to render instance-center heatmaps;
  # TODO confirm against the parser implementation.
  sigma: float = 8.0
  # BUG FIX: the original `dtype = 'float32'` had no type annotation, so
  # dataclasses treated it as a plain class attribute rather than a config
  # field — it could not be set via the config/override system.
  dtype: str = 'float32'
@dataclasses.dataclass
class DataDecoder(common.DataDecoder):
  """Data decoder config."""
  # Decoder for plain serialized tf.Example records; inherits any other
  # decoder options from common.DataDecoder.
  simple_decoder: common.TfExampleDecoder = common.TfExampleDecoder()
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training."""
  # Converts serialized input records into decoded tensors.
  decoder: DataDecoder = DataDecoder()
  # Converts decoded tensors into model inputs/labels (see Parser above).
  parser: Parser = Parser()
  # Input container format; 'tfrecord' is the only value used here —
  # NOTE(review): confirm which other file types the input pipeline accepts.
  file_type: str = 'tfrecord'
@dataclasses.dataclass
class PanopticDeeplabHead(hyperparams.Config):
"""Panoptic Deeplab head config."""
......@@ -39,7 +67,6 @@ class PanopticDeeplabHead(hyperparams.Config):
low_level: Union[List[int], Tuple[int]] = (3, 2)
low_level_num_filters: Union[List[int], Tuple[int]] = (64, 32)
@dataclasses.dataclass
class SemanticHead(PanopticDeeplabHead):
"""Semantic head config."""
......@@ -53,6 +80,8 @@ class InstanceHead(PanopticDeeplabHead):
@dataclasses.dataclass
class PanopticDeeplabPostProcessor(hyperparams.Config):
"""Panoptic Deeplab PostProcessing config."""
output_size: List[int] = dataclasses.field(
default_factory=list)
center_score_threshold: float = 0.1
thing_class_ids: List[int] = dataclasses.field(default_factory=list)
label_divisor: int = 256 * 256 * 256
......@@ -60,11 +89,12 @@ class PanopticDeeplabPostProcessor(hyperparams.Config):
ignore_label: int = 0
nms_kernel: int = 41
keep_k_centers: int = 400
rescale_predictions: bool = True
@dataclasses.dataclass
class PanopticDeeplab(hyperparams.Config):
"""Panoptic Deeplab model config."""
num_classes: int = 0
num_classes: int = 2
input_size: List[int] = dataclasses.field(default_factory=list)
min_level: int = 3
max_level: int = 6
......@@ -75,4 +105,44 @@ class PanopticDeeplab(hyperparams.Config):
semantic_head: SemanticHead = SemanticHead()
instance_head: InstanceHead = InstanceHead()
shared_decoder: bool = False
generate_panoptic_masks: bool = True
post_processor: PanopticDeeplabPostProcessor = PanopticDeeplabPostProcessor()
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Panoptic Deeplab losses config."""
  label_smoothing: float = 0.0
  # Pixels with this label are excluded from the loss; matches
  # Parser.ignore_label elsewhere in this file.
  ignore_label: int = 0
  # Optional per-class weights for the segmentation loss; empty list means
  # unweighted — TODO confirm against the loss implementation.
  class_weights: List[float] = dataclasses.field(default_factory=list)
  l2_weight_decay: float = 1e-4
  use_groundtruth_dimension: bool = True
  # Fraction of hardest pixels kept for the segmentation loss (hard-pixel
  # mining); presumably in (0, 1] — verify semantics in the loss code.
  top_k_percent_pixels: float = 0.15
  # Relative weights of the three loss terms.
  segmentation_loss_weight: float = 1.0
  # FIX: default was the int literal 200 despite the `float` annotation;
  # normalized to 200.0 for annotation/value consistency (same value).
  center_heatmap_loss_weight: float = 200.0
  center_offset_loss_weight: float = 0.01
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  """Evaluation config."""
  # NOTE(review): named `ignored_label` here but `ignore_label` in the other
  # configs in this file — likely dictated by the evaluator's expected field
  # name; confirm before unifying.
  ignored_label: int = 0
  max_instances_per_category: int = 256
  # Multiplier used to combine semantic and instance ids into one panoptic id.
  offset: int = 256 * 256 * 256
  # Per-class indicator of "thing" (instance) classes.
  # NOTE(review): typed List[float] although it looks like a boolean mask —
  # confirm the evaluator's expected dtype.
  is_thing: List[float] = dataclasses.field(
      default_factory=list)
  rescale_predictions: bool = True
  report_per_class_pq: bool = False
  report_per_class_iou: bool = False
  report_train_mean_iou: bool = True  # Turning this off can speed up training.
@dataclasses.dataclass
class PanopticDeeplabTask(cfg.TaskConfig):
  """Panoptic Deeplab task config: model, data, losses and evaluation."""
  model: PanopticDeeplab = PanopticDeeplab()
  train_data: DataConfig = DataConfig(is_training=True)
  # Eval keeps partial final batches (drop_remainder=False) so every example
  # is evaluated.
  validation_data: DataConfig = DataConfig(
      is_training=False,
      drop_remainder=False)
  losses: Losses = Losses()
  # Path to a pretrained checkpoint; None means train from scratch.
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[str, List[str]] = 'all'  # all, backbone, and/or decoder
  # Presumably a COCO-style annotation file consumed by evaluation;
  # TODO confirm against the task implementation.
  annotation_file: Optional[str] = None
  evaluation: Evaluation = Evaluation()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment