# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data loader and input processing."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

22
import tensorflow as tf
23
24
25
26
27
28
29
30

from typing import Text, Optional
from official.modeling.hyperparams import params_dict
from official.vision.detection.dataloader import factory
from official.vision.detection.dataloader import mode_keys as ModeKeys


class InputFn(object):
  """Input function that creates dataset from files.

  An instance is a callable: construct it once with the file pattern and
  parsing parameters, then call it to obtain a batched `tf.data.Dataset`.
  """

  def __init__(self,
               file_pattern: Text,
               params: params_dict.ParamsDict,
               mode: Text,
               batch_size: int,
               num_examples: Optional[int] = -1):
    """Initialize.

    Args:
      file_pattern: the file pattern for the data example (TFRecords).
      params: the parameter object for constructing example parser and model.
        If it exposes `train.input_sharding` (training mode) or
        `eval.input_sharding` (any other mode), that value overrides the
        default sharding behavior.
      mode: the mode key. The pipeline is treated as a training pipeline only
        when this equals ModeKeys.TRAIN.
      batch_size: the data batch size.
      num_examples: If positive, only takes this number of examples and raise
        tf.errors.OutOfRangeError after that. If non-positive or None, it
        will be ignored.
    """
    assert file_pattern is not None
    assert mode is not None
    assert batch_size is not None
    self._file_pattern = file_pattern
    self._mode = mode
    self._is_training = (mode == ModeKeys.TRAIN)
    self._batch_size = batch_size
    self._num_examples = num_examples
    self._parser_fn = factory.parser_generator(params, mode)
    self._dataset_fn = tf.data.TFRecordDataset

    # Default: shard the input files only when not training. An explicit
    # `input_sharding` entry in the matching section of `params` takes
    # precedence; a missing section/entry is deliberately ignored so older
    # configs keep working with the default.
    self._input_sharding = (not self._is_training)
    try:
      if self._is_training:
        self._input_sharding = params.train.input_sharding
      else:
        self._input_sharding = params.eval.input_sharding
    except AttributeError:
      pass

  def __call__(self, ctx=None, batch_size: Optional[int] = None):
    """Provides tf.data.Dataset object.

    Args:
      ctx: context object. When provided, its `num_input_pipelines` and
        `input_pipeline_id` attributes are used to shard the input files
        (presumably a `tf.distribute.InputContext` -- confirm with callers).
      batch_size: expected batch size input data. If falsy (None or 0), the
        batch size given at construction time is used.

    Returns:
      tf.data.Dataset object.
    """
    if not batch_size:
      batch_size = self._batch_size
    assert batch_size is not None
    dataset = tf.data.Dataset.list_files(
        self._file_pattern, shuffle=self._is_training)

    # Give each input pipeline a disjoint subset of the files.
    if self._input_sharding and ctx and ctx.num_input_pipelines > 1:
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    # Cache the (possibly sharded) list of file names, not the records.
    dataset = dataset.cache()

    if self._is_training:
      dataset = dataset.repeat()

    # Read records from multiple files concurrently.
    dataset = dataset.interleave(
        map_func=self._dataset_fn,
        cycle_length=32,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._is_training:
      dataset = dataset.shuffle(1000)
    # `num_examples` is Optional: treat None like a non-positive value
    # (no limit) instead of failing on the `None > 0` comparison.
    if self._num_examples is not None and self._num_examples > 0:
      dataset = dataset.take(self._num_examples)

    # Parses the fetched records to input tensors for model function.
    dataset = dataset.map(
        self._parser_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset