input_reader.py 3.57 KB
Newer Older
Yeqing Li's avatar
Yeqing Li committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Yeqing Li's avatar
Yeqing Li committed
14

15
16
17
18
19
20
"""Data loader and input processing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

21
import tensorflow as tf
22
23
24
25
26
27
28
29

from typing import Text, Optional
from official.modeling.hyperparams import params_dict
from official.vision.detection.dataloader import factory
from official.vision.detection.dataloader import mode_keys as ModeKeys


class InputFn(object):
  """Input function that creates dataset from files."""

  def __init__(self,
               file_pattern: Text,
               params: params_dict.ParamsDict,
               mode: Text,
               batch_size: int,
               num_examples: Optional[int] = -1):
    """Initialize.

    Args:
      file_pattern: the file pattern for the data example (TFRecords).
      params: the parameter object for constructing example parser and model.
        `params.train.input_sharding` / `params.eval.input_sharding` are read
        if present.
      mode: ModeKeys.TRAIN or ModeKeys.EVAL.
      batch_size: the data batch size.
      num_examples: If positive, only takes this number of examples and raise
        tf.errors.OutOfRangeError after that. If non-positive or None, it will
        be ignored.
    """
    assert file_pattern is not None
    assert mode is not None
    assert batch_size is not None
    self._file_pattern = file_pattern
    self._mode = mode
    self._is_training = (mode == ModeKeys.TRAIN)
    self._batch_size = batch_size
    # `None` is allowed by the `Optional[int]` annotation; map it to the
    # "no limit" sentinel so the `> 0` comparison in `__call__` is well-defined.
    self._num_examples = -1 if num_examples is None else num_examples
    self._parser_fn = factory.parser_generator(params, mode)
    self._dataset_fn = tf.data.TFRecordDataset

    # Default: shard input files when not training; the per-mode config value
    # overrides this when it exists.
    self._input_sharding = (not self._is_training)
    try:
      if self._is_training:
        self._input_sharding = params.train.input_sharding
      else:
        self._input_sharding = params.eval.input_sharding
    except AttributeError:
      # The config does not define `input_sharding`; keep the default above.
      pass

  def __call__(self, ctx=None, batch_size: Optional[int] = None):
    """Provides tf.data.Dataset object.

    Args:
      ctx: context object. Presumably a `tf.distribute.InputContext` (TODO
        confirm); only its `num_input_pipelines` and `input_pipeline_id`
        attributes are read, and only when input sharding is enabled.
      batch_size: expected batch size input data. When falsy (None or 0), the
        batch size given to the constructor is used instead.

    Returns:
      tf.data.Dataset object.
    """
    if not batch_size:
      batch_size = self._batch_size
    assert batch_size is not None
    dataset = tf.data.Dataset.list_files(
        self._file_pattern, shuffle=self._is_training)

    # Give each input pipeline a disjoint subset of the matched files.
    if self._input_sharding and ctx and ctx.num_input_pipelines > 1:
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    # NOTE(review): cache() follows the (possibly shuffled) file listing, so
    # every repeat() epoch replays the file order from the first pass. Confirm
    # that losing per-epoch file reshuffling is intended.
    dataset = dataset.cache()

    if self._is_training:
      dataset = dataset.repeat()

    # Read records from up to 32 files concurrently.
    dataset = dataset.interleave(
        map_func=self._dataset_fn,
        cycle_length=32,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._is_training:
      dataset = dataset.shuffle(1000)
    if self._num_examples is not None and self._num_examples > 0:
      dataset = dataset.take(self._num_examples)

    # Parses the fetched records to input tensors for model function.
    dataset = dataset.map(
        self._parser_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Partial final batches are dropped so every batch has a static size.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset