Commit 8d6f5271 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Unify DataDecoder from retinanet and maskrcnn and move it into common so it...

Unify DataDecoder from retinanet and maskrcnn and move it into common so it can be used by other tasks. Add DataDecoder as optional field for classification and segmentation DataConfig so overriding with DataDecoder can work for all tasks.

PiperOrigin-RevId: 392493075
parent ee8b45fb
......@@ -15,15 +15,44 @@
# Lint as: python3
"""Common configurations."""
import dataclasses
from typing import Optional
# Import libraries
import dataclasses
# Import libraries
from official.core import config_definitions as cfg
from official.modeling import hyperparams
@dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config):
"""A simple TF Example decoder config."""
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
@dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config):
"""TF Example decoder with label map config."""
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
label_map: str = ''
@dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig):
"""Data decoder config.
Attributes:
type: 'str', type of data decoder be used, one of the fields below.
simple_decoder: simple TF Example decoder config.
label_map_decoder: TF Example decoder with label map config.
"""
type: Optional[str] = 'simple_decoder'
simple_decoder: TfExampleDecoder = TfExampleDecoder()
label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
@dataclasses.dataclass
class RandAugment(hyperparams.Config):
"""Configuration for RandAugment."""
......
......@@ -14,11 +14,10 @@
# Lint as: python3
"""Image classification configuration definition."""
import dataclasses
import os
from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
......@@ -44,6 +43,7 @@ class DataConfig(cfg.DataConfig):
image_field_key: str = 'image/encoded'
label_field_key: str = 'image/class/label'
decode_jpeg_only: bool = True
decoder: Optional[common.DataDecoder] = common.DataDecoder()
# Keep for backward compatibility.
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'.
......
......@@ -29,26 +29,6 @@ from official.vision.beta.configs import backbones
# pylint: disable=missing-class-docstring
@dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config):
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
@dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config):
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
label_map: str = ''
@dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig):
type: Optional[str] = 'simple_decoder'
simple_decoder: TfExampleDecoder = TfExampleDecoder()
label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
@dataclasses.dataclass
class Parser(hyperparams.Config):
num_channels: int = 3
......@@ -73,7 +53,7 @@ class DataConfig(cfg.DataConfig):
global_batch_size: int = 0
is_training: bool = False
dtype: str = 'bfloat16'
decoder: DataDecoder = DataDecoder()
decoder: common.DataDecoder = common.DataDecoder()
parser: Parser = Parser()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
......
......@@ -15,9 +15,9 @@
# Lint as: python3
"""RetinaNet configuration definition."""
import dataclasses
import os
from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
......@@ -29,22 +29,22 @@ from official.vision.beta.configs import backbones
# pylint: disable=missing-class-docstring
# Keep for backward compatibility.
@dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config):
regenerate_source_id: bool = False
class TfExampleDecoder(common.TfExampleDecoder):
"""A simple TF Example decoder config."""
# Keep for backward compatibility.
@dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config):
regenerate_source_id: bool = False
label_map: str = ''
class TfExampleDecoderLabelMap(common.TfExampleDecoderLabelMap):
"""TF Example decoder with label map config."""
# Keep for backward compatibility.
@dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig):
type: Optional[str] = 'simple_decoder'
simple_decoder: TfExampleDecoder = TfExampleDecoder()
label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
class DataDecoder(common.DataDecoder):
"""Data decoder config."""
@dataclasses.dataclass
......@@ -67,7 +67,7 @@ class DataConfig(cfg.DataConfig):
global_batch_size: int = 0
is_training: bool = False
dtype: str = 'bfloat16'
decoder: DataDecoder = DataDecoder()
decoder: common.DataDecoder = common.DataDecoder()
parser: Parser = Parser()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
......
......@@ -14,10 +14,10 @@
# Lint as: python3
"""Semantic segmentation configuration definition."""
import dataclasses
import os
from typing import List, Optional, Union
import dataclasses
import numpy as np
from official.core import exp_factory
......@@ -52,6 +52,7 @@ class DataConfig(cfg.DataConfig):
aug_rand_hflip: bool = True
drop_remainder: bool = True
file_type: str = 'tfrecord'
decoder: Optional[common.DataDecoder] = common.DataDecoder()
@dataclasses.dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment