# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Asynchronously generate TFRecords files for NCF."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import contextlib
import datetime
import gc
import multiprocessing
import json
import os
import pickle
import signal
import sys
import tempfile
import time
import timeit
import traceback
import typing

import numpy as np
import tensorflow as tf

from absl import app as absl_app
from absl import flags

from official.datasets import movielens
from official.recommendation import constants as rconst
from official.recommendation import stat_utils
from official.recommendation import popen_helper


_log_file = None


def log_msg(msg):
  """Include timestamp info when logging messages to a file."""
  if flags.FLAGS.use_tf_logging:
    tf.logging.info(msg)
    return

  if flags.FLAGS.redirect_logs:
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    print("[{}] {}".format(timestamp, msg), file=_log_file)
  else:
    print(msg, file=_log_file)
  if _log_file:
    _log_file.flush()


def get_cycle_folder_name(i):
  return "cycle_{}".format(str(i).zfill(5))


def _process_shard(args):
  # type: ((str, int, int, int, bool, bool)) -> (np.ndarray, np.ndarray, np.ndarray)
  """Read a shard of training data and return training vectors.

  Args:
    shard_path: The filepath of the positive instance training shard.
    num_items: The cardinality of the item set.
    num_neg: The number of negatives to generate per positive example.
    seed: Random seed to be used when generating negatives.
    is_training: Generate training (True) or eval (False) data.
    match_mlperf: Match the MLPerf reference behavior.

  Returns:
    A tuple of (users, items, labels) NumPy arrays for the shard.
  """
  shard_path, num_items, num_neg, seed, is_training, match_mlperf = args
  np.random.seed(seed)

  # The choice to store the training shards in files rather than in memory
  # is motivated by the fact that multiprocessing serializes arguments,
  # transmits them to map workers, and then deserializes them. By storing the
  # training shards in files, the serialization work only needs to be done once.
  #
  # A similar effect could be achieved by simply holding pickled bytes in
  # memory; however, the processing is not I/O bound, so that optimization is
  # unnecessary.
  with tf.gfile.Open(shard_path, "rb") as f:
    shard = pickle.load(f)

  users = shard[rconst.TRAIN_KEY][movielens.USER_COLUMN]
  items = shard[rconst.TRAIN_KEY][movielens.ITEM_COLUMN]

  if not is_training:
    # For eval, there is one positive which was held out from the training set.
    test_positive_dict = dict(zip(
        shard[rconst.EVAL_KEY][movielens.USER_COLUMN],
        shard[rconst.EVAL_KEY][movielens.ITEM_COLUMN]))

  delta = users[1:] - users[:-1]
  boundaries = ([0] + (np.argwhere(delta)[:, 0] + 1).tolist() +
                [users.shape[0]])
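  # `users` is grouped by user id, so the nonzero entries of `delta` mark the
  # points where the id changes. Illustrative values only:
  #   users      = [0, 0, 0, 1, 1, 2]
  #   delta      = [0, 0, 1, 0, 1]
  #   boundaries = [0, 3, 5, 6]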

  user_blocks = []
  item_blocks = []
  label_blocks = []
  for i in range(len(boundaries) - 1):
    assert len(set(users[boundaries[i]:boundaries[i+1]])) == 1
    current_user = users[boundaries[i]]

    positive_items = items[boundaries[i]:boundaries[i+1]]
    positive_set = set(positive_items)
    if positive_items.shape[0] != len(positive_set):
      raise ValueError("Duplicate entries detected.")

    if is_training:
      n_pos = len(positive_set)
      negatives = stat_utils.sample_with_exclusion(
          num_items, positive_set, n_pos * num_neg, replacement=True)

    else:
      if not match_mlperf:
        # The mlperf reference allows the holdout item to appear as a negative.
        # Including it in the positive set makes the eval more stringent,
        # because an appearance of the test item would be removed by
        # deduplication rules. (Effectively resulting in a minute reduction of
        # NUM_EVAL_NEGATIVES)
        positive_set.add(test_positive_dict[current_user])

      negatives = stat_utils.sample_with_exclusion(
          num_items, positive_set, num_neg, replacement=match_mlperf)
      positive_set = [test_positive_dict[current_user]]
      n_pos = len(positive_set)
      assert n_pos == 1

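    # Each user contributes one contiguous block: n_pos positive items followed
    # by n_pos * num_neg sampled negatives, with labels 1 and 0 respectively.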
    user_blocks.append(current_user * np.ones(
        (n_pos * (1 + num_neg),), dtype=np.int32))
    item_blocks.append(
        np.array(list(positive_set) + negatives, dtype=np.uint16))
    labels_for_user = np.zeros((n_pos * (1 + num_neg),), dtype=np.int8)
    labels_for_user[:n_pos] = 1
    label_blocks.append(labels_for_user)

  users_out = np.concatenate(user_blocks)
  items_out = np.concatenate(item_blocks)
  labels_out = np.concatenate(label_blocks)

  assert users_out.shape == items_out.shape == labels_out.shape
  return users_out, items_out, labels_out


def _construct_record(users, items, labels=None, dupe_mask=None):
  """Convert NumPy arrays into a TFRecords entry."""
  feature_dict = {
      movielens.USER_COLUMN: tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[memoryview(users).tobytes()])),
      movielens.ITEM_COLUMN: tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[memoryview(items).tobytes()])),
  }
  if labels is not None:
    feature_dict["labels"] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[memoryview(labels).tobytes()]))

  if dupe_mask is not None:
    feature_dict[rconst.DUPLICATE_MASK] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[memoryview(dupe_mask).tobytes()]))

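  # A reader can recover the arrays by parsing the Example and reinterpreting
  # the raw bytes, e.g. (a sketch only; the real parsing is done by the NCF
  # input_fn):
  #   parsed = tf.parse_single_example(serialized_example, {
  #       movielens.USER_COLUMN: tf.FixedLenFeature([], tf.string)})
  #   users = tf.decode_raw(parsed[movielens.USER_COLUMN], tf.int32)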
  return tf.train.Example(
      features=tf.train.Features(feature=feature_dict)).SerializeToString()


def sigint_handler(signal, frame):
  log_msg("Shutting down worker.")


def init_worker():
  signal.signal(signal.SIGINT, sigint_handler)


def _construct_records(
    is_training,          # type: bool
    train_cycle,          # type: typing.Optional[int]
    num_workers,          # type: int
    cache_paths,          # type: rconst.Paths
    num_readers,          # type: int
    num_neg,              # type: int
    num_positives,        # type: int
    num_items,            # type: int
    epochs_per_cycle,     # type: int
    batch_size,           # type: int
    training_shards,      # type: typing.List[str]
    deterministic=False,  # type: bool
    match_mlperf=False    # type: bool
    ):
  """Generate false negatives and write TFRecords files.

  Args:
    is_training: Whether training (True) or eval (False) records are created.
    train_cycle: Integer of which cycle the generated data is for.
    num_workers: Number of multiprocessing workers to use for negative
      generation.
    cache_paths: Paths object with information about where to write files.
    num_readers: The number of reader datasets in the input_fn.
    num_neg: The number of false negatives per positive example.
    num_positives: The number of positive examples. This value is used
      to pre-allocate arrays while the imap is still running. (NumPy does not
      allow dynamic arrays.)
    num_items: The cardinality of the item set.
    epochs_per_cycle: The number of epochs worth of data to construct.
    batch_size: The expected batch size used during training. This is used
      to properly batch data when writing TFRecords.
    training_shards: The pickled positive examples from which to generate
      negatives.
    deterministic: Whether shard outputs must be consumed in a deterministic
      order. (Set when a fixed random seed is supplied.)
    match_mlperf: Whether to match the MLPerf reference behavior.
  """
  st = timeit.default_timer()

  if not is_training:
    # Later logic assumes that all items for a given user are in the same batch.
    assert not batch_size % (rconst.NUM_EVAL_NEGATIVES + 1)
    assert num_neg == rconst.NUM_EVAL_NEGATIVES

  assert epochs_per_cycle == 1 or is_training
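  # Cap the pool size so that every worker has at least one shard to process.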
  num_workers = min([num_workers, len(training_shards) * epochs_per_cycle])

  num_pts = num_positives * (1 + num_neg)

  # Equivalent to `int(ceil(num_pts / batch_size)) * batch_size`, but without
  # precision concerns
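  # (For example, num_pts = 10 and batch_size = 4 give
  # num_pts_with_padding = 12 and num_padding = 2.)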
  num_pts_with_padding = (num_pts + batch_size - 1) // batch_size * batch_size
  num_padding = num_pts_with_padding - num_pts

  # We choose a different random seed for each process, so that the processes
  # will not all choose the same random numbers.
  process_seeds = [stat_utils.random_int32()
                   for _ in training_shards * epochs_per_cycle]
  map_args = [
      (shard, num_items, num_neg, process_seeds[i], is_training, match_mlperf)
      for i, shard in enumerate(training_shards * epochs_per_cycle)]

  with popen_helper.get_pool(num_workers, init_worker) as pool:
    map_fn = pool.imap if deterministic else pool.imap_unordered  # pylint: disable=no-member
    data_generator = map_fn(_process_shard, map_args)
    data = [
        np.zeros(shape=(num_pts_with_padding,), dtype=np.int32) - 1,
        np.zeros(shape=(num_pts_with_padding,), dtype=np.uint16),
        np.zeros(shape=(num_pts_with_padding,), dtype=np.int8),
    ]

    # Training data is shuffled. Evaluation data MUST not be shuffled.
    # Downstream processing depends on the fact that evaluation data for a given
    # user is grouped within a batch.
    if is_training:
      index_destinations = np.random.permutation(num_pts)
    else:
      index_destinations = np.arange(num_pts)

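    # Scatter each worker's output into the preallocated arrays: the training
    # permutation shuffles points globally, while the eval identity mapping
    # preserves the per-user grouping.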
    start_ind = 0
    for data_segment in data_generator:
      n_in_segment = data_segment[0].shape[0]
      dest = index_destinations[start_ind:start_ind + n_in_segment]
      start_ind += n_in_segment
      for i in range(3):
        data[i][dest] = data_segment[i]

  assert np.sum(data[0] == -1) == num_padding

  if is_training:
    if num_padding:
      # In order to have a full final batch, pad by randomly repeating points
      # sampled from the rest of the data.
      pad_sample_indices = np.random.randint(
          low=0, high=num_pts, size=(num_padding,))
      dest = np.arange(start=start_ind, stop=start_ind + num_padding)
      start_ind += num_padding
      for i in range(3):
        data[i][dest] = data[i][pad_sample_indices]
  else:
    # For Evaluation, padding is all zeros. The evaluation input_fn knows how
    # to interpret and discard the zero padded entries.
    data[0][num_pts:] = 0

  # Check that no points were overlooked.
  assert not np.sum(data[0] == -1)

  batches_per_file = np.ceil(num_pts_with_padding / batch_size / num_readers)
  current_file_id = -1
  current_batch_id = -1
  batches_by_file = [[] for _ in range(num_readers)]

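  # Assign batch ids to reader files in contiguous chunks of batches_per_file,
  # so that each of the num_readers output files holds roughly the same number
  # of batches.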
  while True:
    current_batch_id += 1
    if (current_batch_id % batches_per_file) == 0:
      current_file_id += 1

    start_ind = current_batch_id * batch_size
    end_ind = start_ind + batch_size
    if end_ind > num_pts_with_padding:
      if start_ind != num_pts_with_padding:
        raise ValueError("Batch padding does not line up")
      break
    batches_by_file[current_file_id].append(current_batch_id)

  if is_training:
    # Empirically it is observed that placing the batch with repeated values at
    # the start rather than the end improves convergence.
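    # (The final batch is the one that contains the repeated padding points, so
    # the swap moves those repeats into the very first batch to be read.)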
    batches_by_file[0][0], batches_by_file[-1][-1] = \
      batches_by_file[-1][-1], batches_by_file[0][0]

  if is_training:
    template = rconst.TRAIN_RECORD_TEMPLATE
    record_dir = os.path.join(cache_paths.train_epoch_dir,
                              get_cycle_folder_name(train_cycle))
    tf.gfile.MakeDirs(record_dir)
  else:
    template = rconst.EVAL_RECORD_TEMPLATE
    record_dir = cache_paths.eval_data_subdir

  batch_count = 0
  for i in range(num_readers):
    fpath = os.path.join(record_dir, template.format(i))
    log_msg("Writing {}".format(fpath))
    with tf.python_io.TFRecordWriter(fpath) as writer:
      for j in batches_by_file[i]:
        start_ind = j * batch_size
        end_ind = start_ind + batch_size
        record_kwargs = dict(
            users=data[0][start_ind:end_ind],
            items=data[1][start_ind:end_ind],
        )

        if is_training:
          record_kwargs["labels"] = data[2][start_ind:end_ind]
        else:
          record_kwargs["dupe_mask"] = stat_utils.mask_duplicates(
              record_kwargs["items"].reshape(-1, num_neg + 1),
              axis=1).flatten().astype(np.int8)

        batch_bytes = _construct_record(**record_kwargs)

        writer.write(batch_bytes)
        batch_count += 1

  # We write to a temp file then atomically rename it to the final file, because
  # writing directly to the final file can cause the main process to read a
  # partially written JSON file.
  ready_file_temp = os.path.join(record_dir, rconst.READY_FILE_TEMP)
  with tf.gfile.Open(ready_file_temp, "w") as f:
    json.dump({
        "batch_size": batch_size,
        "batch_count": batch_count,
    }, f)
  ready_file = os.path.join(record_dir, rconst.READY_FILE)
  tf.gfile.Rename(ready_file_temp, ready_file)

  if is_training:
    log_msg("Cycle {} complete. Total time: {:.1f} seconds"
            .format(train_cycle, timeit.default_timer() - st))
  else:
    log_msg("Eval construction complete. Total time: {:.1f} seconds"
            .format(timeit.default_timer() - st))


def _generation_loop(num_workers,           # type: int
                     cache_paths,           # type: rconst.Paths
                     num_readers,           # type: int
                     num_neg,               # type: int
                     num_train_positives,   # type: int
                     num_items,             # type: int
                     num_users,             # type: int
                     epochs_per_cycle,      # type: int
                     train_batch_size,      # type: int
                     eval_batch_size,       # type: int
                     deterministic,         # type: bool
                     match_mlperf           # type: bool
                    ):
  # type: (...) -> None
  """Primary run loop for data file generation."""

  log_msg("Signaling that I am alive.")
  with tf.gfile.Open(cache_paths.subproc_alive, "w") as f:
    f.write("Generation subproc has started.")

  @atexit.register
  def remove_alive_file():
    try:
      tf.gfile.Remove(cache_paths.subproc_alive)
    except tf.errors.NotFoundError:
      return  # Main thread has already deleted the entire cache dir.

  log_msg("Entering generation loop.")
  tf.gfile.MakeDirs(cache_paths.train_epoch_dir)
  tf.gfile.MakeDirs(cache_paths.eval_data_subdir)

  training_shards = [os.path.join(cache_paths.train_shard_subdir, i) for i in
                     tf.gfile.ListDirectory(cache_paths.train_shard_subdir)]
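  # Each entry is a pickled shard of positive examples; each shard is one unit
  # of work for _process_shard.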

  shared_kwargs = dict(
      num_workers=multiprocessing.cpu_count(), cache_paths=cache_paths,
      num_readers=num_readers, num_items=num_items,
      training_shards=training_shards, deterministic=deterministic,
      match_mlperf=match_mlperf
  )

  # Training blocks on the creation of the first epoch, so the num_workers
  # limit is not respected for this invocation
  train_cycle = 0
  _construct_records(
      is_training=True, train_cycle=train_cycle, num_neg=num_neg,
      num_positives=num_train_positives, epochs_per_cycle=epochs_per_cycle,
      batch_size=train_batch_size, **shared_kwargs)

  # Construct evaluation set.
  shared_kwargs["num_workers"] = num_workers
  _construct_records(
      is_training=False, train_cycle=None, num_neg=rconst.NUM_EVAL_NEGATIVES,
      num_positives=num_users, epochs_per_cycle=1, batch_size=eval_batch_size,
      **shared_kwargs)

  wait_count = 0
  start_time = time.time()
  while True:
    ready_epochs = tf.gfile.ListDirectory(cache_paths.train_epoch_dir)
    if len(ready_epochs) >= rconst.CYCLES_TO_BUFFER:
      wait_count += 1
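      # Poll roughly every five seconds: the sleep target is wait_count * 5
      # seconds of total elapsed wait time.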
      sleep_time = max([0, wait_count * 5 - (time.time() - start_time)])
      time.sleep(sleep_time)

      if (wait_count % 10) == 0:
        log_msg("Waited {} times for data to be consumed."
                .format(wait_count))

      if time.time() - start_time > rconst.TIMEOUT_SECONDS:
        log_msg("Waited more than {} seconds. Concluding that this "
                "process is orphaned and exiting gracefully."
                .format(rconst.TIMEOUT_SECONDS))
        sys.exit()

      continue

    train_cycle += 1
    _construct_records(
        is_training=True, train_cycle=train_cycle, num_neg=num_neg,
        num_positives=num_train_positives, epochs_per_cycle=epochs_per_cycle,
        batch_size=train_batch_size, **shared_kwargs)

    wait_count = 0
    start_time = time.time()
    gc.collect()


def _parse_flagfile(flagfile):
  """Fill flags with flagfile written by the main process."""
  tf.logging.info("Waiting for flagfile to appear at {}..."
                  .format(flagfile))
  start_time = time.time()
  while not tf.gfile.Exists(flagfile):
    if time.time() - start_time > rconst.TIMEOUT_SECONDS:
      log_msg("Waited more than {} seconds. Concluding that this "
              "process is orphaned and exiting gracefully."
              .format(rconst.TIMEOUT_SECONDS))
      sys.exit()
    time.sleep(1)
  tf.logging.info("flagfile found.")

  # The `flags` module opens `flagfile` with `open`, which does not work on
  # Google Cloud Storage and similar file systems.
  _, flagfile_temp = tempfile.mkstemp()
  tf.gfile.Copy(flagfile, flagfile_temp, overwrite=True)

  flags.FLAGS([__file__, "--flagfile", flagfile_temp])
  tf.gfile.Remove(flagfile_temp)


def main(_):
  global _log_file
  cache_paths = rconst.Paths(
      data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)

  flagfile = os.path.join(cache_paths.cache_root, rconst.FLAGFILE)
  _parse_flagfile(flagfile)

  redirect_logs = flags.FLAGS.redirect_logs

  log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
  log_path = os.path.join(cache_paths.data_dir, log_file_name)
  if log_path.startswith("gs://") and redirect_logs:
    fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
    print("Unable to log to {}. Falling back to {}"
          .format(log_path, fallback_log_file))
    log_path = fallback_log_file

  # This server is generally run in a subprocess.
  if redirect_logs:
    print("Redirecting output of data_async_generation.py process to {}"
          .format(log_path))
    _log_file = open(log_path, "wt")  # Note: not tf.gfile.Open().
  try:
    log_msg("sys.argv: {}".format(" ".join(sys.argv)))

    if flags.FLAGS.seed is not None:
      np.random.seed(flags.FLAGS.seed)

    _generation_loop(
        num_workers=flags.FLAGS.num_workers,
        cache_paths=cache_paths,
        num_readers=flags.FLAGS.num_readers,
        num_neg=flags.FLAGS.num_neg,
        num_train_positives=flags.FLAGS.num_train_positives,
        num_items=flags.FLAGS.num_items,
        num_users=flags.FLAGS.num_users,
        epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
        train_batch_size=flags.FLAGS.train_batch_size,
        eval_batch_size=flags.FLAGS.eval_batch_size,
        deterministic=flags.FLAGS.seed is not None,
        match_mlperf=flags.FLAGS.ml_perf,
    )
  except KeyboardInterrupt:
    log_msg("KeyboardInterrupt registered.")
  except:
    traceback.print_exc(file=_log_file)
    raise
  finally:
    log_msg("Shutting down generation subprocess.")
    sys.stdout.flush()
    sys.stderr.flush()
    if redirect_logs:
      _log_file.close()


def define_flags():
  """Construct flags for the server."""
  flags.DEFINE_integer(name="num_workers", default=multiprocessing.cpu_count(),
                       help="Size of the negative generation worker pool.")
  flags.DEFINE_string(name="data_dir", default=None,
                      help="The data root. (used to construct cache paths.)")
  flags.DEFINE_string(name="cache_id", default=None,
                      help="The cache_id generated in the main process.")
  flags.DEFINE_integer(name="num_readers", default=4,
                       help="Number of reader datasets in training. This sets "
                            "how the epoch files are sharded.")
  flags.DEFINE_integer(name="num_neg", default=None,
                       help="The number of negative instances to pair with a "
                            "positive instance.")
  flags.DEFINE_integer(name="num_train_positives", default=None,
                       help="The number of positive training examples.")
  flags.DEFINE_integer(name="num_items", default=None,
                       help="Number of items from which to select negatives.")
  flags.DEFINE_integer(name="num_users", default=None,
                       help="The number of unique users. Used for evaluation.")
  flags.DEFINE_integer(name="epochs_per_cycle", default=1,
                       help="The number of epochs of training data to produce "
                            "at a time.")
  flags.DEFINE_integer(name="train_batch_size", default=None,
                       help="The batch size with which training TFRecords will "
                            "be chunked.")
  flags.DEFINE_integer(name="eval_batch_size", default=None,
                       help="The batch size with which evaluation TFRecords "
                            "will be chunked.")
  flags.DEFINE_boolean(name="redirect_logs", default=False,
                       help="Catch logs and write them to a file. "
                            "(Useful if this is run as a subprocess)")
  flags.DEFINE_boolean(name="use_tf_logging", default=False,
                       help="Use tf.logging instead of log file.")
  flags.DEFINE_integer(name="seed", default=None,
                       help="NumPy random seed to set at startup. If not "
                            "specified, a seed will not be set.")
  flags.DEFINE_boolean(name="ml_perf", default=None,
                       help="Match MLPerf. See ncf_main.py for details.")

  flags.mark_flags_as_required(["data_dir", "cache_id"])

if __name__ == "__main__":
  define_flags()
  absl_app.run(main)