# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Asynchronously generate TFRecords files for NCF."""

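# This script runs as a subprocess of the main NCF training job: it waits for
# the main process to write a flagfile, then repeatedly samples negatives and
# writes sharded TFRecords (plus a JSON "ready" file) for each training cycle,
# as well as a one-time evaluation set.
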
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import contextlib
import datetime
import gc
import multiprocessing
import json
import os
import pickle
import signal
import sys
import tempfile
import time
import timeit
import traceback
import typing

import numpy as np
import tensorflow as tf

from absl import app as absl_app
from absl import flags

from official.datasets import movielens
from official.recommendation import constants as rconst
from official.recommendation import stat_utils
from official.recommendation import popen_helper


_log_file = None


def log_msg(msg):
  """Include timestamp info when logging messages to a file."""
  if flags.FLAGS.use_tf_logging:
    tf.logging.info(msg)
    return

  if flags.FLAGS.redirect_logs:
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    print("[{}] {}".format(timestamp, msg), file=_log_file)
  else:
    print(msg, file=_log_file)
  if _log_file:
    _log_file.flush()


def get_cycle_folder_name(i):
  return "cycle_{}".format(str(i).zfill(5))


def _process_shard(args):
  # type: ((str, int, int, int, bool, bool)) -> (np.ndarray, np.ndarray, np.ndarray)
  """Read a shard of training data and return training vectors.

  Args:
    shard_path: The filepath of the positive instance training shard.
    num_items: The cardinality of the item set.
    num_neg: The number of negatives to generate per positive example.
    seed: Random seed to be used when generating negatives.
    is_training: Generate training (True) or eval (False) data.
    match_mlperf: Match the MLPerf reference behavior.
  """
  shard_path, num_items, num_neg, seed, is_training, match_mlperf = args
  np.random.seed(seed)

  # The choice to store the training shards in files rather than in memory
  # is motivated by the fact that multiprocessing serializes arguments,
  # transmits them to map workers, and then deserializes them. By storing the
  # training shards in files, the serialization work only needs to be done once.
  #
  # A similar effect could be achieved by simply holding pickled bytes in
  # memory; however, the processing is not I/O bound, so keeping everything in
  # memory offers no benefit.
  with tf.gfile.Open(shard_path, "rb") as f:
    shard = pickle.load(f)

  users = shard[rconst.TRAIN_KEY][movielens.USER_COLUMN]
  items = shard[rconst.TRAIN_KEY][movielens.ITEM_COLUMN]

  if not is_training:
    # For eval, there is one positive which was held out from the training set.
    test_positive_dict = dict(zip(
        shard[rconst.EVAL_KEY][movielens.USER_COLUMN],
        shard[rconst.EVAL_KEY][movielens.ITEM_COLUMN]))

  delta = users[1:] - users[:-1]
  boundaries = ([0] + (np.argwhere(delta)[:, 0] + 1).tolist() +
                [users.shape[0]])
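  # `boundaries` marks where each user's block of positive items begins and
  # ends; e.g. users == [5, 5, 7, 7, 7, 2] gives boundaries == [0, 2, 5, 6].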

  user_blocks = []
  item_blocks = []
  label_blocks = []
  for i in range(len(boundaries) - 1):
    assert len(set(users[boundaries[i]:boundaries[i+1]])) == 1
    current_user = users[boundaries[i]]

    positive_items = items[boundaries[i]:boundaries[i+1]]
    positive_set = set(positive_items)
    if positive_items.shape[0] != len(positive_set):
      raise ValueError("Duplicate entries detected.")

    if is_training:
      n_pos = len(positive_set)
      negatives = stat_utils.sample_with_exclusion(
          num_items, positive_set, n_pos * num_neg, replacement=True)

    else:
      if not match_mlperf:
        # The mlperf reference allows the holdout item to appear as a negative.
        # Including it in the positive set makes the eval more stringent,
        # because an appearance of the test item would be removed by
        # deduplication rules. (Effectively resulting in a minute reduction of
        # NUM_EVAL_NEGATIVES)
        positive_set.add(test_positive_dict[current_user])

      negatives = stat_utils.sample_with_exclusion(
          num_items, positive_set, num_neg, replacement=match_mlperf)
      positive_set = [test_positive_dict[current_user]]
      n_pos = len(positive_set)
      assert n_pos == 1

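    # Per-user layout: the first n_pos items are the user's positives
    # (label 1), followed by n_pos * num_neg sampled negatives (label 0).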
    user_blocks.append(current_user * np.ones(
        (n_pos * (1 + num_neg),), dtype=np.int32))
    item_blocks.append(
        np.array(list(positive_set) + negatives, dtype=np.uint16))
    labels_for_user = np.zeros((n_pos * (1 + num_neg),), dtype=np.int8)
    labels_for_user[:n_pos] = 1
    label_blocks.append(labels_for_user)

  users_out = np.concatenate(user_blocks)
  items_out = np.concatenate(item_blocks)
  labels_out = np.concatenate(label_blocks)

  assert users_out.shape == items_out.shape == labels_out.shape
  return users_out, items_out, labels_out


def _construct_record(users, items, labels=None, dupe_mask=None):
  """Convert NumPy arrays into a TFRecords entry."""
  feature_dict = {
      movielens.USER_COLUMN: tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[memoryview(users).tobytes()])),
      movielens.ITEM_COLUMN: tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[memoryview(items).tobytes()])),
  }
  if labels is not None:
    feature_dict["labels"] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[memoryview(labels).tobytes()]))

  if dupe_mask is not None:
    feature_dict[rconst.DUPLICATE_MASK] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[memoryview(dupe_mask).tobytes()]))

  return tf.train.Example(
      features=tf.train.Features(feature=feature_dict)).SerializeToString()


def sigint_handler(signal, frame):
  log_msg("Shutting down worker.")


def init_worker():
  signal.signal(signal.SIGINT, sigint_handler)


def _construct_records(
    is_training,          # type: bool
    train_cycle,          # type: typing.Optional[int]
    num_workers,          # type: int
    cache_paths,          # type: rconst.Paths
    num_readers,          # type: int
    num_neg,              # type: int
    num_positives,        # type: int
    num_items,            # type: int
    epochs_per_cycle,     # type: int
    batch_size,           # type: int
    training_shards,      # type: typing.List[str]
    deterministic=False,  # type: bool
    match_mlperf=False    # type: bool
    ):
  """Generate false negatives and write TFRecords files.

  Args:
    is_training: Whether training (True) or eval (False) records are created.
    train_cycle: Integer of which cycle the generated data is for.
    num_workers: Number of multiprocessing workers to use for negative
      generation.
    cache_paths: Paths object with information of where to write files.
    num_readers: The number of reader datasets in the input_fn. This number is
      approximate; fewer shards will be created if not all shards are assigned
      batches. This can occur due to discretization in the assignment process.
    num_neg: The number of false negatives per positive example.
    num_positives: The number of positive examples. This value is used
      to pre-allocate arrays while the imap is still running. (NumPy does not
      allow dynamic arrays.)
    num_items: The cardinality of the item set.
    epochs_per_cycle: The number of epochs worth of data to construct.
    batch_size: The expected batch size used during training. This is used
      to properly batch data when writing TFRecords.
    training_shards: The pickled positive examples from which to generate
      negatives.
  """
  st = timeit.default_timer()

  if not is_training:
    # Later logic assumes that all items for a given user are in the same batch.
    assert not batch_size % (rconst.NUM_EVAL_NEGATIVES + 1)
    assert num_neg == rconst.NUM_EVAL_NEGATIVES

  assert epochs_per_cycle == 1 or is_training
  num_workers = min([num_workers, len(training_shards) * epochs_per_cycle])

  num_pts = num_positives * (1 + num_neg)

  # Equivalent to `int(ceil(num_pts / batch_size)) * batch_size`, but without
  # precision concerns
  num_pts_with_padding = (num_pts + batch_size - 1) // batch_size * batch_size
  num_padding = num_pts_with_padding - num_pts
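  # For example, num_pts=10 with batch_size=4 gives num_pts_with_padding=12
  # and num_padding=2.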

  # We choose a different random seed for each process, so that the processes
  # will not all choose the same random numbers.
  process_seeds = [stat_utils.random_int32()
                   for _ in training_shards * epochs_per_cycle]
  map_args = [
      (shard, num_items, num_neg, process_seeds[i], is_training, match_mlperf)
      for i, shard in enumerate(training_shards * epochs_per_cycle)]

  with popen_helper.get_pool(num_workers, init_worker) as pool:
    map_fn = pool.imap if deterministic else pool.imap_unordered  # pylint: disable=no-member
    data_generator = map_fn(_process_shard, map_args)
    data = [
        np.zeros(shape=(num_pts_with_padding,), dtype=np.int32) - 1,
        np.zeros(shape=(num_pts_with_padding,), dtype=np.uint16),
        np.zeros(shape=(num_pts_with_padding,), dtype=np.int8),
    ]
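
    # The user array is filled with -1 as a sentinel so that, after the worker
    # outputs are scattered in below, any slot still holding -1 is known to be
    # padding.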

    # Training data is shuffled. Evaluation data MUST not be shuffled.
    # Downstream processing depends on the fact that evaluation data for a given
    # user is grouped within a batch.
    if is_training:
      index_destinations = np.random.permutation(num_pts)
    else:
      index_destinations = np.arange(num_pts)

    start_ind = 0
    for data_segment in data_generator:
      n_in_segment = data_segment[0].shape[0]
      dest = index_destinations[start_ind:start_ind + n_in_segment]
      start_ind += n_in_segment
      for i in range(3):
        data[i][dest] = data_segment[i]

  assert np.sum(data[0] == -1) == num_padding

  if is_training:
    if num_padding:
      # In order to have a full final batch, pad by repeating points sampled
      # at random from the real data.
      pad_sample_indices = np.random.randint(
          low=0, high=num_pts, size=(num_padding,))
      dest = np.arange(start=start_ind, stop=start_ind + num_padding)
      start_ind += num_padding
      for i in range(3):
        data[i][dest] = data[i][pad_sample_indices]
  else:
    # For Evaluation, padding is all zeros. The evaluation input_fn knows how
    # to interpret and discard the zero padded entries.
    data[0][num_pts:] = 0

  # Check that no points were overlooked.
  assert not np.sum(data[0] == -1)

  batches_per_file = np.ceil(num_pts_with_padding / batch_size / num_readers)
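
  # Batches are assigned to reader files in contiguous runs of
  # `batches_per_file`; files that receive no batches are dropped afterwards.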
  current_file_id = -1
  current_batch_id = -1
  batches_by_file = [[] for _ in range(num_readers)]

  while True:
    current_batch_id += 1
    if (current_batch_id % batches_per_file) == 0:
      current_file_id += 1

    start_ind = current_batch_id * batch_size
    end_ind = start_ind + batch_size
    if end_ind > num_pts_with_padding:
      if start_ind != num_pts_with_padding:
        raise ValueError("Batch padding does not line up")
      break
    batches_by_file[current_file_id].append(current_batch_id)

  # Drop shards which were not assigned batches
  batches_by_file = [i for i in batches_by_file if i]
  num_readers = len(batches_by_file)

  if is_training:
    # Empirically it is observed that placing the batch with repeated values at
    # the start rather than the end improves convergence.
    batches_by_file[0][0], batches_by_file[-1][-1] = \
      batches_by_file[-1][-1], batches_by_file[0][0]

  if is_training:
    template = rconst.TRAIN_RECORD_TEMPLATE
    record_dir = os.path.join(cache_paths.train_epoch_dir,
                              get_cycle_folder_name(train_cycle))
    tf.gfile.MakeDirs(record_dir)
  else:
    template = rconst.EVAL_RECORD_TEMPLATE
    record_dir = cache_paths.eval_data_subdir

  batch_count = 0
  for i in range(num_readers):
    fpath = os.path.join(record_dir, template.format(i))
    log_msg("Writing {}".format(fpath))
    with tf.python_io.TFRecordWriter(fpath) as writer:
      for j in batches_by_file[i]:
        start_ind = j * batch_size
        end_ind = start_ind + batch_size
        record_kwargs = dict(
            users=data[0][start_ind:end_ind],
            items=data[1][start_ind:end_ind],
        )

        if is_training:
          record_kwargs["labels"] = data[2][start_ind:end_ind]
        else:
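          # Flag items that appear more than once within each user's
          # (1 + num_neg) eval candidates so that downstream evaluation can
          # discount the repeats.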
          record_kwargs["dupe_mask"] = stat_utils.mask_duplicates(
              record_kwargs["items"].reshape(-1, num_neg + 1),
              axis=1).flatten().astype(np.int8)

        batch_bytes = _construct_record(**record_kwargs)

        writer.write(batch_bytes)
        batch_count += 1

  # We write to a temp file then atomically rename it to the final file, because
  # writing directly to the final file can cause the main process to read a
  # partially written JSON file.
  ready_file_temp = os.path.join(record_dir, rconst.READY_FILE_TEMP)
  with tf.gfile.Open(ready_file_temp, "w") as f:
    json.dump({
        "batch_size": batch_size,
        "batch_count": batch_count,
    }, f)
  ready_file = os.path.join(record_dir, rconst.READY_FILE)
  tf.gfile.Rename(ready_file_temp, ready_file)

  if is_training:
    log_msg("Cycle {} complete. Total time: {:.1f} seconds"
            .format(train_cycle, timeit.default_timer() - st))
  else:
    log_msg("Eval construction complete. Total time: {:.1f} seconds"
            .format(timeit.default_timer() - st))


def _generation_loop(num_workers,           # type: int
                     cache_paths,           # type: rconst.Paths
                     num_readers,           # type: int
                     num_neg,               # type: int
                     num_train_positives,   # type: int
                     num_items,             # type: int
                     num_users,             # type: int
                     epochs_per_cycle,      # type: int
                     train_batch_size,      # type: int
                     eval_batch_size,       # type: int
                     deterministic,         # type: bool
                     match_mlperf           # type: bool
                    ):
  # type: (...) -> None
  """Primary run loop for data file generation."""

  log_msg("Signaling that I am alive.")
  with tf.gfile.Open(cache_paths.subproc_alive, "w") as f:
    f.write("Generation subproc has started.")

  @atexit.register
  def remove_alive_file():
    try:
      tf.gfile.Remove(cache_paths.subproc_alive)
    except tf.errors.NotFoundError:
      return  # Main thread has already deleted the entire cache dir.

  log_msg("Entering generation loop.")
  tf.gfile.MakeDirs(cache_paths.train_epoch_dir)
  tf.gfile.MakeDirs(cache_paths.eval_data_subdir)

  training_shards = [os.path.join(cache_paths.train_shard_subdir, i) for i in
                     tf.gfile.ListDirectory(cache_paths.train_shard_subdir)]

  shared_kwargs = dict(
      num_workers=multiprocessing.cpu_count(), cache_paths=cache_paths,
      num_readers=num_readers, num_items=num_items,
      training_shards=training_shards, deterministic=deterministic,
      match_mlperf=match_mlperf
  )

  # Training blocks on the creation of the first epoch, so the num_workers
  # limit is not respected for this invocation.
  train_cycle = 0
  _construct_records(
      is_training=True, train_cycle=train_cycle, num_neg=num_neg,
      num_positives=num_train_positives, epochs_per_cycle=epochs_per_cycle,
      batch_size=train_batch_size, **shared_kwargs)

  # Construct evaluation set.
  shared_kwargs["num_workers"] = num_workers
  _construct_records(
      is_training=False, train_cycle=None, num_neg=rconst.NUM_EVAL_NEGATIVES,
      num_positives=num_users, epochs_per_cycle=1, batch_size=eval_batch_size,
      **shared_kwargs)

  wait_count = 0
  start_time = time.time()
  while True:
    ready_epochs = tf.gfile.ListDirectory(cache_paths.train_epoch_dir)
    if len(ready_epochs) >= rconst.CYCLES_TO_BUFFER:
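      # Enough cycles are already buffered; poll roughly once every five
      # seconds until the main process consumes one.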
      wait_count += 1
      sleep_time = max([0, wait_count * 5 - (time.time() - start_time)])
      time.sleep(sleep_time)

      if (wait_count % 10) == 0:
        log_msg("Waited {} times for data to be consumed."
                .format(wait_count))

      if time.time() - start_time > rconst.TIMEOUT_SECONDS:
        log_msg("Waited more than {} seconds. Concluding that this "
                "process is orphaned and exiting gracefully."
                .format(rconst.TIMEOUT_SECONDS))
        sys.exit()

      continue

    train_cycle += 1
    _construct_records(
        is_training=True, train_cycle=train_cycle, num_neg=num_neg,
        num_positives=num_train_positives, epochs_per_cycle=epochs_per_cycle,
        batch_size=train_batch_size, **shared_kwargs)

    wait_count = 0
    start_time = time.time()
    gc.collect()


def _parse_flagfile(flagfile):
  """Fill flags with flagfile written by the main process."""
  tf.logging.info("Waiting for flagfile to appear at {}..."
                  .format(flagfile))
  start_time = time.time()
  while not tf.gfile.Exists(flagfile):
    if time.time() - start_time > rconst.TIMEOUT_SECONDS:
      log_msg("Waited more than {} seconds. Concluding that this "
              "process is orphaned and exiting gracefully."
              .format(rconst.TIMEOUT_SECONDS))
      sys.exit()
    time.sleep(1)
  tf.logging.info("flagfile found.")

  # The `flags` module opens `flagfile` with `open`, which does not work with
  # Google Cloud Storage and similar remote filesystems.
  _, flagfile_temp = tempfile.mkstemp()
  tf.gfile.Copy(flagfile, flagfile_temp, overwrite=True)

  flags.FLAGS([__file__, "--flagfile", flagfile_temp])
  tf.gfile.Remove(flagfile_temp)


def main(_):
  global _log_file
  cache_paths = rconst.Paths(
      data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)

  flagfile = os.path.join(cache_paths.cache_root, rconst.FLAGFILE)
  _parse_flagfile(flagfile)

  redirect_logs = flags.FLAGS.redirect_logs

  log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
  log_path = os.path.join(cache_paths.data_dir, log_file_name)
  if log_path.startswith("gs://") and redirect_logs:
    fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
    print("Unable to log to {}. Falling back to {}"
          .format(log_path, fallback_log_file))
    log_path = fallback_log_file

  # This server is generally run in a subprocess.
  if redirect_logs:
    print("Redirecting output of data_async_generation.py process to {}"
          .format(log_path))
    _log_file = open(log_path, "wt")  # Note: not tf.gfile.Open().
  try:
    log_msg("sys.argv: {}".format(" ".join(sys.argv)))

    if flags.FLAGS.seed is not None:
      np.random.seed(flags.FLAGS.seed)

    _generation_loop(
        num_workers=flags.FLAGS.num_workers,
        cache_paths=cache_paths,
        num_readers=flags.FLAGS.num_readers,
        num_neg=flags.FLAGS.num_neg,
        num_train_positives=flags.FLAGS.num_train_positives,
        num_items=flags.FLAGS.num_items,
        num_users=flags.FLAGS.num_users,
        epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
        train_batch_size=flags.FLAGS.train_batch_size,
        eval_batch_size=flags.FLAGS.eval_batch_size,
        deterministic=flags.FLAGS.seed is not None,
        match_mlperf=flags.FLAGS.ml_perf,
    )
  except KeyboardInterrupt:
    log_msg("KeyboardInterrupt registered.")
  except:
    traceback.print_exc(file=_log_file)
    raise
  finally:
    log_msg("Shutting down generation subprocess.")
    sys.stdout.flush()
    sys.stderr.flush()
    if redirect_logs:
      _log_file.close()


def define_flags():
  """Construct flags for the server."""
  flags.DEFINE_integer(name="num_workers", default=multiprocessing.cpu_count(),
                       help="Size of the negative generation worker pool.")
  flags.DEFINE_string(name="data_dir", default=None,
                      help="The data root. (used to construct cache paths.)")
  flags.DEFINE_string(name="cache_id", default=None,
                      help="The cache_id generated in the main process.")
  flags.DEFINE_integer(name="num_readers", default=4,
                       help="Number of reader datasets in training. This sets "
                            "how the epoch files are sharded.")
  flags.DEFINE_integer(name="num_neg", default=None,
                       help="The number of negative instances to pair with a "
                            "positive instance.")
  flags.DEFINE_integer(name="num_train_positives", default=None,
                       help="The number of positive training examples.")
  flags.DEFINE_integer(name="num_items", default=None,
                       help="Number of items from which to select negatives.")
  flags.DEFINE_integer(name="num_users", default=None,
                       help="The number of unique users. Used for evaluation.")
  flags.DEFINE_integer(name="epochs_per_cycle", default=1,
                       help="The number of epochs of training data to produce "
                            "at a time.")
  flags.DEFINE_integer(name="train_batch_size", default=None,
                       help="The batch size with which training TFRecords will "
                            "be chunked.")
  flags.DEFINE_integer(name="eval_batch_size", default=None,
                       help="The batch size with which evaluation TFRecords "
                            "will be chunked.")
  flags.DEFINE_boolean(name="redirect_logs", default=False,
                       help="Catch logs and write them to a file. "
                            "(Useful if this is run as a subprocess)")
  flags.DEFINE_boolean(name="use_tf_logging", default=False,
                       help="Use tf.logging instead of log file.")
  flags.DEFINE_integer(name="seed", default=None,
                       help="NumPy random seed to set at startup. If not "
                            "specified, a seed will not be set.")
  flags.DEFINE_boolean(name="ml_perf", default=None,
                       help="Match MLPerf. See ncf_main.py for details.")

  flags.mark_flags_as_required(["data_dir", "cache_id"])

if __name__ == "__main__":
  define_flags()
  absl_app.run(main)