Commit 4a2b9846 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Unify DataDecoder from retinanet and maskrcnn and move it into common so it...

Unify DataDecoder from retinanet and maskrcnn and move it into common so it can be used by other tasks. Add DataDecoder as optional field for classification and segmentation DataConfig so overriding with DataDecoder can work for all tasks.

PiperOrigin-RevId: 392493075
parent e31d1693
...@@ -15,15 +15,44 @@ ...@@ -15,15 +15,44 @@
# Lint as: python3 # Lint as: python3
"""Common configurations.""" """Common configurations."""
import dataclasses
from typing import Optional from typing import Optional
# Import libraries
import dataclasses # Import libraries
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.modeling import hyperparams from official.modeling import hyperparams
@dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config):
"""A simple TF Example decoder config."""
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
@dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config):
"""TF Example decoder with label map config."""
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
label_map: str = ''
@dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig):
"""Data decoder config.
Attributes:
type: 'str', type of data decoder be used, one of the fields below.
simple_decoder: simple TF Example decoder config.
label_map_decoder: TF Example decoder with label map config.
"""
type: Optional[str] = 'simple_decoder'
simple_decoder: TfExampleDecoder = TfExampleDecoder()
label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
@dataclasses.dataclass @dataclasses.dataclass
class RandAugment(hyperparams.Config): class RandAugment(hyperparams.Config):
"""Configuration for RandAugment.""" """Configuration for RandAugment."""
......
...@@ -14,11 +14,10 @@ ...@@ -14,11 +14,10 @@
# Lint as: python3 # Lint as: python3
"""Image classification configuration definition.""" """Image classification configuration definition."""
import dataclasses
import os import os
from typing import List, Optional from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
...@@ -44,6 +43,7 @@ class DataConfig(cfg.DataConfig): ...@@ -44,6 +43,7 @@ class DataConfig(cfg.DataConfig):
image_field_key: str = 'image/encoded' image_field_key: str = 'image/encoded'
label_field_key: str = 'image/class/label' label_field_key: str = 'image/class/label'
decode_jpeg_only: bool = True decode_jpeg_only: bool = True
decoder: Optional[common.DataDecoder] = common.DataDecoder()
# Keep for backward compatibility. # Keep for backward compatibility.
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'. aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'.
......
...@@ -29,26 +29,6 @@ from official.vision.beta.configs import backbones ...@@ -29,26 +29,6 @@ from official.vision.beta.configs import backbones
# pylint: disable=missing-class-docstring # pylint: disable=missing-class-docstring
@dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config):
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
@dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config):
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
label_map: str = ''
@dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig):
type: Optional[str] = 'simple_decoder'
simple_decoder: TfExampleDecoder = TfExampleDecoder()
label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
@dataclasses.dataclass @dataclasses.dataclass
class Parser(hyperparams.Config): class Parser(hyperparams.Config):
num_channels: int = 3 num_channels: int = 3
...@@ -73,7 +53,7 @@ class DataConfig(cfg.DataConfig): ...@@ -73,7 +53,7 @@ class DataConfig(cfg.DataConfig):
global_batch_size: int = 0 global_batch_size: int = 0
is_training: bool = False is_training: bool = False
dtype: str = 'bfloat16' dtype: str = 'bfloat16'
decoder: DataDecoder = DataDecoder() decoder: common.DataDecoder = common.DataDecoder()
parser: Parser = Parser() parser: Parser = Parser()
shuffle_buffer_size: int = 10000 shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord' file_type: str = 'tfrecord'
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
# Lint as: python3 # Lint as: python3
"""RetinaNet configuration definition.""" """RetinaNet configuration definition."""
import dataclasses
import os import os
from typing import List, Optional from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
...@@ -29,22 +29,22 @@ from official.vision.beta.configs import backbones ...@@ -29,22 +29,22 @@ from official.vision.beta.configs import backbones
# pylint: disable=missing-class-docstring # pylint: disable=missing-class-docstring
# Keep for backward compatibility.
@dataclasses.dataclass @dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config): class TfExampleDecoder(common.TfExampleDecoder):
regenerate_source_id: bool = False """A simple TF Example decoder config."""
# Keep for backward compatibility.
@dataclasses.dataclass @dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config): class TfExampleDecoderLabelMap(common.TfExampleDecoderLabelMap):
regenerate_source_id: bool = False """TF Example decoder with label map config."""
label_map: str = ''
# Keep for backward compatibility.
@dataclasses.dataclass @dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig): class DataDecoder(common.DataDecoder):
type: Optional[str] = 'simple_decoder' """Data decoder config."""
simple_decoder: TfExampleDecoder = TfExampleDecoder()
label_map_decoder: TfExampleDecoderLabelMap = TfExampleDecoderLabelMap()
@dataclasses.dataclass @dataclasses.dataclass
...@@ -67,7 +67,7 @@ class DataConfig(cfg.DataConfig): ...@@ -67,7 +67,7 @@ class DataConfig(cfg.DataConfig):
global_batch_size: int = 0 global_batch_size: int = 0
is_training: bool = False is_training: bool = False
dtype: str = 'bfloat16' dtype: str = 'bfloat16'
decoder: DataDecoder = DataDecoder() decoder: common.DataDecoder = common.DataDecoder()
parser: Parser = Parser() parser: Parser = Parser()
shuffle_buffer_size: int = 10000 shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord' file_type: str = 'tfrecord'
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
# Lint as: python3 # Lint as: python3
"""Semantic segmentation configuration definition.""" """Semantic segmentation configuration definition."""
import dataclasses
import os import os
from typing import List, Optional, Union from typing import List, Optional, Union
import dataclasses
import numpy as np import numpy as np
from official.core import exp_factory from official.core import exp_factory
...@@ -52,6 +52,7 @@ class DataConfig(cfg.DataConfig): ...@@ -52,6 +52,7 @@ class DataConfig(cfg.DataConfig):
aug_rand_hflip: bool = True aug_rand_hflip: bool = True
drop_remainder: bool = True drop_remainder: bool = True
file_type: str = 'tfrecord' file_type: str = 'tfrecord'
decoder: Optional[common.DataDecoder] = common.DataDecoder()
@dataclasses.dataclass @dataclasses.dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment