# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data pipeline for the Ranking model.

This module defines various input datasets for the Ranking model.
"""

from typing import List
import tensorflow as tf

from official.recommendation.ranking.configs import config


class CriteoTsvReader:
  """Input reader callable for pre-processed Criteo data.

  Raw Criteo data is assumed to be preprocessed in the following way:
  1. Missing values are replaced with zeros.
  2. Negative values are replaced with zeros.
  3. Integer features are transformed by log(x+1) and are hence tf.float32.
  4. Categorical features are bucketized and are hence tf.int32.
  """

  def __init__(self,
               file_pattern: str,
               params: config.DataConfig,
               num_dense_features: int,
               vocab_sizes: List[int],
               use_synthetic_data: bool = False):
    self._file_pattern = file_pattern
    self._params = params
    self._num_dense_features = num_dense_features
    self._vocab_sizes = vocab_sizes
    self._use_synthetic_data = use_synthetic_data

  def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
    params = self._params
    # Per replica batch size.
    batch_size = ctx.get_per_replica_batch_size(
        params.global_batch_size) if ctx else params.global_batch_size
    if self._use_synthetic_data:
      return self._generate_synthetic_data(ctx, batch_size)

    @tf.function
    def _parse_fn(example: tf.Tensor):
      """Parser function for pre-processed Criteo TSV records."""
      label_defaults = [[0.0]]
      dense_defaults = [
          [0.0] for _ in range(self._num_dense_features)
      ]
      num_sparse_features = len(self._vocab_sizes)
      categorical_defaults = [
          [0] for _ in range(num_sparse_features)
      ]
      record_defaults = label_defaults + dense_defaults + categorical_defaults
      fields = tf.io.decode_csv(
          example, record_defaults, field_delim='\t', na_value='-1')

      num_labels = 1
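      # Column 0 of `fields` is the label; the next `num_dense_features`
      # columns are the dense features and the remaining columns are the
      # categorical features.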
      label = tf.reshape(fields[0], [batch_size, 1])

      features = {}
      num_dense = len(dense_defaults)

      dense_features = []
      offset = num_labels
      for idx in range(num_dense):
        dense_features.append(fields[idx + offset])
      features['dense_features'] = tf.stack(dense_features, axis=1)

      offset += num_dense
      features['sparse_features'] = {}

      for idx in range(num_sparse_features):
        features['sparse_features'][str(idx)] = fields[idx + offset]

      return features, label

    filenames = tf.data.Dataset.list_files(self._file_pattern, shuffle=False)

    # Shard the full dataset according to host number.
    # Each host will get 1 / num_of_hosts portion of the data.
    if params.sharding and ctx and ctx.num_input_pipelines > 1:
      filenames = filenames.shard(ctx.num_input_pipelines,
                                  ctx.input_pipeline_id)

    num_shards_per_host = 1
    if params.sharding:
      num_shards_per_host = params.num_shards_per_host

    def make_dataset(shard_index):
      filenames_for_shard = filenames.shard(num_shards_per_host, shard_index)
      dataset = tf.data.TextLineDataset(filenames_for_shard)
      if params.is_training:
        dataset = dataset.repeat()
      dataset = dataset.batch(batch_size, drop_remainder=True)
      dataset = dataset.map(_parse_fn,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    indices = tf.data.Dataset.range(num_shards_per_host)
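    # Each index selects one of `num_shards_per_host` file shards; interleave
    # reads them in parallel and merges their batches into a single stream.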
    dataset = indices.interleave(
        map_func=make_dataset,
        cycle_length=params.cycle_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset

  def _generate_synthetic_data(self, ctx: tf.distribute.InputContext,
                               batch_size: int) -> tf.data.Dataset:
    """Creates synthetic data based on the parameter batch size.

    Args:
      ctx: Input Context
      batch_size: per replica batch size.

    Returns:
      The synthetic dataset.
    """
    params = self._params
    num_dense = self._num_dense_features
    num_replicas = ctx.num_replicas_in_sync if ctx else 1

    if params.is_training:
      dataset_size = 1000 * batch_size * num_replicas
    else:
      dataset_size = 1000 * batch_size * num_replicas
    dense_tensor = tf.random.uniform(
        shape=(dataset_size, num_dense), maxval=1.0, dtype=tf.float32)

    sparse_tensors = []
    for size in self._vocab_sizes:
      sparse_tensors.append(
          tf.random.uniform(
              shape=(dataset_size,), maxval=int(size), dtype=tf.int32))

    sparse_tensor_elements = {
        str(i): sparse_tensors[i] for i in range(len(sparse_tensors))
    }

    # The mean of the dense features is in the [0, 1] interval.
    dense_tensor_mean = tf.math.reduce_mean(dense_tensor, axis=1)

    sparse_tensors = tf.stack(sparse_tensors, axis=-1)
    sparse_tensors_mean = tf.math.reduce_sum(sparse_tensors, axis=1)
    sparse_tensors_mean = tf.cast(sparse_tensors_mean, dtype=tf.float32)
    # Normalize by the total vocabulary size so the mean is in [0, 1].
    sparse_tensors_mean /= sum(self._vocab_sizes)
    # The raw label is in the [0, 1] interval.
    label_tensor = (dense_tensor_mean + sparse_tensors_mean) / 2.0
    # Round at the 0.5 threshold to obtain 0/1 labels.
    label_tensor = tf.cast(label_tensor + 0.5, tf.int32)

    input_elem = {'dense_features': dense_tensor,
                  'sparse_features': sparse_tensor_elements}, label_tensor

    dataset = tf.data.Dataset.from_tensor_slices(input_elem)
    dataset = dataset.cache()
    if params.is_training:
      dataset = dataset.repeat()

    return dataset.batch(batch_size, drop_remainder=True)


def train_input_fn(params: config.Task) -> CriteoTsvReader:
  """Returns callable object of batched training examples.

  Args:
    params: hyperparams to create input pipelines.

  Returns:
    CriteoTsvReader callable for training dataset.
  """
  return CriteoTsvReader(
      file_pattern=params.train_data.input_path,
      params=params.train_data,
      vocab_sizes=params.model.vocab_sizes,
      num_dense_features=params.model.num_dense_features,
      use_synthetic_data=params.use_synthetic_data)


def eval_input_fn(params: config.Task) -> CriteoTsvReader:
  """Returns callable object of batched eval examples.

  Args:
    params: hyperparams to create input pipelines.

  Returns:
    CriteoTsvReader callable for eval dataset.
  """

  return CriteoTsvReader(
      file_pattern=params.validation_data.input_path,
      params=params.validation_data,
      vocab_sizes=params.model.vocab_sizes,
      num_dense_features=params.model.num_dense_features,
      use_synthetic_data=params.use_synthetic_data)
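

# Example usage sketch: the readers returned by `train_input_fn` /
# `eval_input_fn` are intended to be passed to a tf.distribute strategy, which
# calls them with a tf.distribute.InputContext so each input pipeline builds
# its own per-replica-batched dataset. `strategy` and `task_config` below are
# assumed to be supplied by the caller; this helper is illustrative only.
def _example_distributed_train_dataset(
    strategy: tf.distribute.Strategy, task_config: config.Task):
  """Example: builds a distributed training dataset from `train_input_fn`."""
  reader = train_input_fn(task_config)
  # The strategy invokes `reader(ctx)` once per input pipeline; `ctx` carries
  # the per-replica batch size and sharding information used above.
  return strategy.distribute_datasets_from_function(reader)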