Commit cb386cfb authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Make dataset_fn as an optional input argument in Mask R-CNN.

PiperOrigin-RevId: 453274537
parent 5997d0df
...@@ -19,7 +19,7 @@ from typing import Any, Dict, Optional, List, Tuple, Mapping ...@@ -19,7 +19,7 @@ from typing import Any, Dict, Optional, List, Tuple, Mapping
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import dataset_fn from official.common import dataset_fn as dataset_fn_lib
from official.core import base_task from official.core import base_task
from official.core import task_factory from official.core import task_factory
from official.vision.configs import maskrcnn as exp_cfg from official.vision.configs import maskrcnn as exp_cfg
...@@ -118,9 +118,11 @@ class MaskRCNNTask(base_task.Task): ...@@ -118,9 +118,11 @@ class MaskRCNNTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
def build_inputs(self, def build_inputs(
params: exp_cfg.DataConfig, self,
input_context: Optional[tf.distribute.InputContext] = None): params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None,
dataset_fn: Optional[dataset_fn_lib.PossibleDatasetType] = None):
"""Build input dataset.""" """Build input dataset."""
decoder_cfg = params.decoder.get() decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder': if params.decoder.type == 'simple_decoder':
...@@ -157,9 +159,12 @@ class MaskRCNNTask(base_task.Task): ...@@ -157,9 +159,12 @@ class MaskRCNNTask(base_task.Task):
include_mask=self._task_config.model.include_mask, include_mask=self._task_config.model.include_mask,
mask_crop_size=params.parser.mask_crop_size) mask_crop_size=params.parser.mask_crop_size)
if not dataset_fn:
dataset_fn = dataset_fn_lib.pick_dataset_fn(params.file_type)
reader = input_reader_factory.input_reader_generator( reader = input_reader_factory.input_reader_generator(
params, params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type), dataset_fn=dataset_fn,
decoder_fn=decoder.decode, decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training)) parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context) dataset = reader.read(input_context=input_context)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment