# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data loader and input processing."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import tensorflow.compat.v2 as tf

from typing import Text, Optional
from official.modeling.hyperparams import params_dict
from official.vision.detection.dataloader import factory
from official.vision.detection.dataloader import mode_keys as ModeKeys


class InputFn(object):
  """Input function that creates dataset from files."""

  def __init__(self,
               file_pattern: Text,
               params: params_dict.ParamsDict,
               mode: Text,
               batch_size: int,
               num_examples: Optional[int] = -1):
    """Initialize.

    Args:
      file_pattern: the file pattern for the data example (TFRecords).
      params: the parameter object for constructing example parser and model.
      mode: ModeKeys.TRAIN or ModeKeys.Eval
      batch_size: the data batch size.
      num_examples: If positive, only takes this number of examples and raise
        tf.errors.OutOfRangeError after that. If non-positive, it will be
        ignored.
    """
    assert file_pattern is not None
    assert mode is not None
    assert batch_size is not None
    self._file_pattern = file_pattern
    self._mode = mode
    self._is_training = (mode == ModeKeys.TRAIN)
    self._batch_size = batch_size
    self._num_examples = num_examples
    self._parser_fn = factory.parser_generator(params, mode)
    self._dataset_fn = tf.data.TFRecordDataset

    # Default: shard eval input across pipelines but not training input; an
    # explicit params.{train,eval}.input_sharding setting overrides this.
    self._input_sharding = (not self._is_training)
    try:
      if self._is_training:
        self._input_sharding = params.train.input_sharding
      else:
        self._input_sharding = params.eval.input_sharding
    except AttributeError:
      # Older param configs may not define input_sharding; keep the default.
      pass

  def __call__(self, ctx=None, batch_size: Optional[int] = None):
    """Provides tf.data.Dataset object.

    Args:
      ctx: context object (e.g. tf.distribute InputContext); when it reports
        multiple input pipelines and sharding is enabled, files are sharded
        across pipelines.
      batch_size: expected batch size input data. Falls back to the batch
        size given at construction time when not provided.

    Returns:
      tf.data.Dataset object.
    """
    if not batch_size:
      batch_size = self._batch_size
    assert batch_size is not None
    dataset = tf.data.Dataset.list_files(
        self._file_pattern, shuffle=self._is_training)

    if self._input_sharding and ctx and ctx.num_input_pipelines > 1:
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    # Cache the (possibly sharded) file list so listing happens only once.
    dataset = dataset.cache()

    if self._is_training:
      dataset = dataset.repeat()

    # Read many TFRecord files in parallel, interleaving their records.
    dataset = dataset.interleave(
        map_func=self._dataset_fn, cycle_length=32,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._is_training:
      # Large shuffle size is critical for 2vm input pipeline. Can use small
      # value (e.g. 64) for 1vm.
      dataset = dataset.shuffle(1000)

    if self._num_examples > 0:
      dataset = dataset.take(self._num_examples)

    # Parses the fetched records to input tensors for model function.
    dataset = dataset.map(
        self._parser_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # drop_remainder=True keeps batch shapes static, as required by TPUs.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset