"tensorflow_models/LICENSE" did not exist on "ed7d404f4780565884d2344de6aa3bce59bb5a2c"
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Asynchronously generate TFRecords files for NCF."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import contextlib
import datetime
import gc
import json
import multiprocessing
import os
import pickle
import signal
import sys
import tempfile
import time
import timeit
import traceback
import typing

import numpy as np
import tensorflow as tf

from absl import app as absl_app
from absl import flags

from official.datasets import movielens
from official.recommendation import constants as rconst
from official.recommendation import stat_utils
from official.recommendation import popen_helper
from official.utils.logs import mlperf_helper


_log_file = None


def log_msg(msg):
  """Include timestamp info when logging messages to a file."""
  if flags.FLAGS.use_tf_logging:
    tf.logging.info(msg)
    return

  if flags.FLAGS.redirect_logs:
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    print("[{}] {}".format(timestamp, msg), file=_log_file)
  else:
    print(msg, file=_log_file)
  if _log_file:
    _log_file.flush()


def get_cycle_folder_name(i):
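  """Return a zero-padded folder name, e.g. 'cycle_00003' for cycle 3."""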
  return "cycle_{}".format(str(i).zfill(5))


def _process_shard(args):
  # type: ((str, int, int, int, bool, bool)) -> (np.ndarray, np.ndarray, np.ndarray)
  """Read a shard of training data and return training vectors.

  Args:
    shard_path: The filepath of the positive instance training shard.
    num_items: The cardinality of the item set.
    num_neg: The number of negatives to generate per positive example.
    seed: Random seed to be used when generating negatives.
    is_training: Generate training (True) or eval (False) data.
    match_mlperf: Match the MLPerf reference behavior.
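
  Returns:
    A tuple of flat (users, items, labels) NumPy arrays containing, for each
    user in the shard, the positive items followed by the sampled negatives.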
  """
  shard_path, num_items, num_neg, seed, is_training, match_mlperf = args
  np.random.seed(seed)

  # The choice to store the training shards in files rather than in memory
  # is motivated by the fact that multiprocessing serializes arguments,
  # transmits them to map workers, and then deserializes them. By storing the
  # training shards in files, the serialization work only needs to be done once.
  #
  # A similar effect could be achieved by simply holding pickled bytes in
  # memory; however, the processing is not I/O bound, so that added
  # complexity is unnecessary.
  with tf.gfile.Open(shard_path, "rb") as f:
    shard = pickle.load(f)

  users = shard[rconst.TRAIN_KEY][movielens.USER_COLUMN]
  items = shard[rconst.TRAIN_KEY][movielens.ITEM_COLUMN]

  if not is_training:
    # For eval, there is one positive which was held out from the training set.
    test_positive_dict = dict(zip(
        shard[rconst.EVAL_KEY][movielens.USER_COLUMN],
        shard[rconst.EVAL_KEY][movielens.ITEM_COLUMN]))

  delta = users[1:] - users[:-1]
  boundaries = ([0] + (np.argwhere(delta)[:, 0] + 1).tolist() +
                [users.shape[0]])

  user_blocks = []
  item_blocks = []
  label_blocks = []
  for i in range(len(boundaries) - 1):
    assert len(set(users[boundaries[i]:boundaries[i+1]])) == 1
    current_user = users[boundaries[i]]

    positive_items = items[boundaries[i]:boundaries[i+1]]
    positive_set = set(positive_items)
    if positive_items.shape[0] != len(positive_set):
      raise ValueError("Duplicate entries detected.")

    if is_training:
      n_pos = len(positive_set)
      negatives = stat_utils.sample_with_exclusion(
          num_items, positive_set, n_pos * num_neg, replacement=True)

    else:
      if not match_mlperf:
        # The mlperf reference allows the holdout item to appear as a negative.
        # Including it in the positive set makes the eval more stringent,
        # because an appearance of the test item would be removed by
        # deduplication rules. (Effectively resulting in a minute reduction of
        # NUM_EVAL_NEGATIVES)
        positive_set.add(test_positive_dict[current_user])

      negatives = stat_utils.sample_with_exclusion(
          num_items, positive_set, num_neg, replacement=match_mlperf)
      positive_set = [test_positive_dict[current_user]]
      n_pos = len(positive_set)
      assert n_pos == 1

    user_blocks.append(current_user * np.ones(
        (n_pos * (1 + num_neg),), dtype=np.int32))
    item_blocks.append(
        np.array(list(positive_set) + negatives, dtype=np.uint16))
    labels_for_user = np.zeros((n_pos * (1 + num_neg),), dtype=np.int8)
    labels_for_user[:n_pos] = 1
    label_blocks.append(labels_for_user)

  users_out = np.concatenate(user_blocks)
  items_out = np.concatenate(item_blocks)
  labels_out = np.concatenate(label_blocks)

  assert users_out.shape == items_out.shape == labels_out.shape
  return users_out, items_out, labels_out


def _construct_record(users, items, labels=None, dupe_mask=None):
  """Convert NumPy arrays into a TFRecords entry."""
  feature_dict = {
      movielens.USER_COLUMN: tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[memoryview(users).tobytes()])),
      movielens.ITEM_COLUMN: tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[memoryview(items).tobytes()])),
  }
  if labels is not None:
    feature_dict["labels"] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[memoryview(labels).tobytes()]))

  if dupe_mask is not None:
    feature_dict[rconst.DUPLICATE_MASK] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[memoryview(dupe_mask).tobytes()]))

  return tf.train.Example(
      features=tf.train.Features(feature=feature_dict)).SerializeToString()
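
# The records produced by _construct_record are parsed back by the NCF input
# pipeline, which lives outside this file. A minimal sketch of that parsing,
# assuming the same dtypes used when serializing (int32 users, uint16 items,
# int8 labels); illustrative only, not the actual consumer code:
#
#   feature_map = {
#       movielens.USER_COLUMN: tf.FixedLenFeature([], tf.string),
#       movielens.ITEM_COLUMN: tf.FixedLenFeature([], tf.string),
#       "labels": tf.FixedLenFeature([], tf.string),
#   }
#   parsed = tf.parse_single_example(serialized_example, feature_map)
#   users = tf.decode_raw(parsed[movielens.USER_COLUMN], tf.int32)
#   items = tf.decode_raw(parsed[movielens.ITEM_COLUMN], tf.uint16)
#   labels = tf.decode_raw(parsed["labels"], tf.int8)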


def sigint_handler(signal, frame):
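  """Absorb SIGINT in pool workers; shutdown is coordinated by the parent."""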
  log_msg("Shutting down worker.")


def init_worker():
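  """Install the SIGINT handler when a pool worker starts."""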
  signal.signal(signal.SIGINT, sigint_handler)


def _construct_records(
    is_training,          # type: bool
    train_cycle,          # type: typing.Optional[int]
    num_workers,          # type: int
    cache_paths,          # type: rconst.Paths
    num_readers,          # type: int
    num_neg,              # type: int
    num_positives,        # type: int
    num_items,            # type: int
    epochs_per_cycle,     # type: int
    batch_size,           # type: int
    training_shards,      # type: typing.List[str]
    deterministic=False,  # type: bool
    match_mlperf=False    # type: bool
    ):
  """Generate false negatives and write TFRecords files.

  Args:
    is_training: Whether to create training (True) or eval (False) records.
    train_cycle: Integer of which cycle the generated data is for.
    num_workers: Number of multiprocessing workers to use for negative
      generation.
    cache_paths: Paths object with information of where to write files.
    num_readers: The number of reader datasets in the input_fn. This number is
      approximate; fewer shards will be created if not all shards are assigned
      batches. This can occur due to discretization in the assignment process.
    num_neg: The number of false negatives per positive example.
    num_positives: The number of positive examples. This value is used
      to pre-allocate arrays while the imap is still running. (NumPy does not
      allow dynamic arrays.)
    num_items: The cardinality of the item set.
    epochs_per_cycle: The number of epochs worth of data to construct.
    batch_size: The expected batch size used during training. This is used
      to properly batch data when writing TFRecords.
    training_shards: The pickled positive example shards from which to
      generate negatives.
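    deterministic: Whether shards must be processed in a reproducible order
      (pool.imap) instead of pool.imap_unordered.
    match_mlperf: Whether to match the MLPerf reference negative sampling
      behavior.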
  """
  st = timeit.default_timer()

  if is_training:
    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_STEP_TRAIN_NEG_GEN)
    mlperf_helper.ncf_print(
        key=mlperf_helper.TAGS.INPUT_HP_NUM_NEG, value=num_neg)

    # set inside _process_shard()
    mlperf_helper.ncf_print(
        key=mlperf_helper.TAGS.INPUT_HP_SAMPLE_TRAIN_REPLACEMENT, value=True)

  else:
    # Later logic assumes that all items for a given user are in the same batch.
    assert not batch_size % (rconst.NUM_EVAL_NEGATIVES + 1)
    assert num_neg == rconst.NUM_EVAL_NEGATIVES

    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_STEP_EVAL_NEG_GEN)

    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_HP_NUM_USERS,
                            value=num_positives)

  assert epochs_per_cycle == 1 or is_training
  num_workers = min([num_workers, len(training_shards) * epochs_per_cycle])

  num_pts = num_positives * (1 + num_neg)

  # Equivalent to `int(ceil(num_pts / batch_size)) * batch_size`, but without
  # precision concerns
  num_pts_with_padding = (num_pts + batch_size - 1) // batch_size * batch_size
  num_padding = num_pts_with_padding - num_pts
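  # For example, num_pts=1050 with batch_size=100 yields
  # num_pts_with_padding=1100 and num_padding=50.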

  # We choose a different random seed for each process, so that the processes
  # will not all choose the same random numbers.
  process_seeds = [stat_utils.random_int32()
                   for _ in training_shards * epochs_per_cycle]
  map_args = [
      (shard, num_items, num_neg, process_seeds[i], is_training, match_mlperf)
      for i, shard in enumerate(training_shards * epochs_per_cycle)]

  with popen_helper.get_pool(num_workers, init_worker) as pool:
    map_fn = pool.imap if deterministic else pool.imap_unordered  # pylint: disable=no-member
    data_generator = map_fn(_process_shard, map_args)
    data = [
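        # Users are initialized to -1 so that padding (unassigned) positions
        # can be detected and counted once all shards have been merged.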
        np.zeros(shape=(num_pts_with_padding,), dtype=np.int32) - 1,
        np.zeros(shape=(num_pts_with_padding,), dtype=np.uint16),
        np.zeros(shape=(num_pts_with_padding,), dtype=np.int8),
    ]

    # Training data is shuffled. Evaluation data MUST not be shuffled.
    # Downstream processing depends on the fact that evaluation data for a given
    # user is grouped within a batch.
    if is_training:
      index_destinations = np.random.permutation(num_pts)
      mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
    else:
      index_destinations = np.arange(num_pts)

    start_ind = 0
    for data_segment in data_generator:
      n_in_segment = data_segment[0].shape[0]
      dest = index_destinations[start_ind:start_ind + n_in_segment]
      start_ind += n_in_segment
      for i in range(3):
        data[i][dest] = data_segment[i]

  assert np.sum(data[0] == -1) == num_padding

  if is_training:
    if num_padding:
      # In order to have a full batch, randomly include points from earlier in
      # the batch.

      mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
      pad_sample_indices = np.random.randint(
          low=0, high=num_pts, size=(num_padding,))
      dest = np.arange(start=start_ind, stop=start_ind + num_padding)
      start_ind += num_padding
      for i in range(3):
        data[i][dest] = data[i][pad_sample_indices]
  else:
    # For Evaluation, padding is all zeros. The evaluation input_fn knows how
    # to interpret and discard the zero padded entries.
    data[0][num_pts:] = 0

  # Check that no points were overlooked.
  assert not np.sum(data[0] == -1)

  if is_training:
    # The number of points is slightly larger than num_pts due to padding.
    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_SIZE,
                            value=int(data[0].shape[0]))
    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_BATCH_SIZE,
                            value=batch_size)
  else:
    # num_pts is logged instead of int(data[0].shape[0]), because the size
    # of the data vector includes zero pads which are ignored.
    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_SIZE, value=num_pts)

  batches_per_file = np.ceil(num_pts_with_padding / batch_size / num_readers)
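  # For example, with num_pts_with_padding=1000, batch_size=100 and
  # num_readers=4 there are 10 batches and batches_per_file=ceil(10/4)=3, so
  # the loop below assigns batches [0-2], [3-5], [6-8] and [9] to the readers.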
  current_file_id = -1
  current_batch_id = -1
  batches_by_file = [[] for _ in range(num_readers)]

  while True:
    current_batch_id += 1
    if (current_batch_id % batches_per_file) == 0:
      current_file_id += 1

    start_ind = current_batch_id * batch_size
    end_ind = start_ind + batch_size
    if end_ind > num_pts_with_padding:
      if start_ind != num_pts_with_padding:
        raise ValueError("Batch padding does not line up")
      break
    batches_by_file[current_file_id].append(current_batch_id)

  # Drop shards which were not assigned batches
  batches_by_file = [i for i in batches_by_file if i]
  num_readers = len(batches_by_file)

  if is_training:
    # Empirically it is observed that placing the batch with repeated values at
    # the start rather than the end improves convergence.
    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
    batches_by_file[0][0], batches_by_file[-1][-1] = \
      batches_by_file[-1][-1], batches_by_file[0][0]

  if is_training:
    template = rconst.TRAIN_RECORD_TEMPLATE
    record_dir = os.path.join(cache_paths.train_epoch_dir,
                              get_cycle_folder_name(train_cycle))
    tf.gfile.MakeDirs(record_dir)
  else:
    template = rconst.EVAL_RECORD_TEMPLATE
    record_dir = cache_paths.eval_data_subdir

  batch_count = 0
  for i in range(num_readers):
    fpath = os.path.join(record_dir, template.format(i))
    log_msg("Writing {}".format(fpath))
    with tf.python_io.TFRecordWriter(fpath) as writer:
      for j in batches_by_file[i]:
        start_ind = j * batch_size
        end_ind = start_ind + batch_size
        record_kwargs = dict(
            users=data[0][start_ind:end_ind],
            items=data[1][start_ind:end_ind],
        )

        if is_training:
          record_kwargs["labels"] = data[2][start_ind:end_ind]
        else:
          record_kwargs["dupe_mask"] = stat_utils.mask_duplicates(
              record_kwargs["items"].reshape(-1, num_neg + 1),
              axis=1).flatten().astype(np.int8)

        batch_bytes = _construct_record(**record_kwargs)

        writer.write(batch_bytes)
        batch_count += 1

  # We write to a temp file then atomically rename it to the final file, because
  # writing directly to the final file can cause the main process to read a
  # partially written JSON file.
  ready_file_temp = os.path.join(record_dir, rconst.READY_FILE_TEMP)
  with tf.gfile.Open(ready_file_temp, "w") as f:
    json.dump({
        "batch_size": batch_size,
        "batch_count": batch_count,
    }, f)
  ready_file = os.path.join(record_dir, rconst.READY_FILE)
  tf.gfile.Rename(ready_file_temp, ready_file)
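  # Consumer-side sketch (illustrative only; the actual reader lives in the
  # main NCF process): once the ready file exists, every shard in record_dir
  # is complete and the metadata can be loaded with, e.g.:
  #   with tf.gfile.Open(ready_file, "r") as f:
  #     meta = json.load(f)  # {"batch_size": ..., "batch_count": ...}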

  if is_training:
    log_msg("Cycle {} complete. Total time: {:.1f} seconds"
            .format(train_cycle, timeit.default_timer() - st))
  else:
    log_msg("Eval construction complete. Total time: {:.1f} seconds"
            .format(timeit.default_timer() - st))


def _generation_loop(num_workers,           # type: int
                     cache_paths,           # type: rconst.Paths
                     num_readers,           # type: int
                     num_neg,               # type: int
                     num_train_positives,   # type: int
                     num_items,             # type: int
                     num_users,             # type: int
                     epochs_per_cycle,      # type: int
                     num_cycles,            # type: int
                     train_batch_size,      # type: int
                     eval_batch_size,       # type: int
                     deterministic,         # type: bool
                     match_mlperf           # type: bool
                    ):
  # type: (...) -> None
  """Primary run loop for data file generation."""

  log_msg("Entering generation loop.")
  tf.gfile.MakeDirs(cache_paths.train_epoch_dir)
  tf.gfile.MakeDirs(cache_paths.eval_data_subdir)

  training_shards = [os.path.join(cache_paths.train_shard_subdir, i) for i in
                     tf.gfile.ListDirectory(cache_paths.train_shard_subdir)]

  shared_kwargs = dict(
      num_workers=multiprocessing.cpu_count(), cache_paths=cache_paths,
      num_readers=num_readers, num_items=num_items,
      training_shards=training_shards, deterministic=deterministic,
      match_mlperf=match_mlperf
  )

  # Training blocks on the creation of the first epoch, so the num_workers
  # limit is not respected for this invocation
  train_cycle = 0
  _construct_records(
      is_training=True, train_cycle=train_cycle, num_neg=num_neg,
      num_positives=num_train_positives, epochs_per_cycle=epochs_per_cycle,
      batch_size=train_batch_size, **shared_kwargs)

  # Construct evaluation set.
  shared_kwargs["num_workers"] = num_workers
  _construct_records(
      is_training=False, train_cycle=None, num_neg=rconst.NUM_EVAL_NEGATIVES,
      num_positives=num_users, epochs_per_cycle=1, batch_size=eval_batch_size,
      **shared_kwargs)

  wait_count = 0
  start_time = time.time()
  while train_cycle < num_cycles:
    ready_epochs = tf.gfile.ListDirectory(cache_paths.train_epoch_dir)
    if len(ready_epochs) >= rconst.CYCLES_TO_BUFFER:
      wait_count += 1
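      # Back off so that checks land roughly every five seconds of elapsed
      # time rather than spinning in a tight loop.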
      sleep_time = max([0, wait_count * 5 - (time.time() - start_time)])
      time.sleep(sleep_time)

      if (wait_count % 10) == 0:
        log_msg("Waited {} times for data to be consumed."
                .format(wait_count))

      if time.time() - start_time > rconst.TIMEOUT_SECONDS:
        log_msg("Waited more than {} seconds. Concluding that this "
                "process is orphaned and exiting gracefully."
                .format(rconst.TIMEOUT_SECONDS))
        sys.exit()

      continue

    train_cycle += 1
    _construct_records(
        is_training=True, train_cycle=train_cycle, num_neg=num_neg,
        num_positives=num_train_positives, epochs_per_cycle=epochs_per_cycle,
        batch_size=train_batch_size, **shared_kwargs)

    wait_count = 0
    start_time = time.time()
    gc.collect()


def wait_for_path(fpath):
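  """Block until fpath exists, or exit if the wait exceeds the timeout."""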
  start_time = time.time()
  while not tf.gfile.Exists(fpath):
    if time.time() - start_time > rconst.TIMEOUT_SECONDS:
      log_msg("Waited more than {} seconds. Concluding that this "
              "process is orphaned and exiting gracefully."
              .format(rconst.TIMEOUT_SECONDS))
      sys.exit()
    time.sleep(1)

def _parse_flagfile(flagfile):
  """Fill flags with flagfile written by the main process."""
  tf.logging.info("Waiting for flagfile to appear at {}..."
                  .format(flagfile))
  wait_for_path(flagfile)
  tf.logging.info("flagfile found.")

  # `flags` module opens `flagfile` with `open`, which does not work on
  # google cloud storage etc.
  _, flagfile_temp = tempfile.mkstemp()
  tf.gfile.Copy(flagfile, flagfile_temp, overwrite=True)

  flags.FLAGS([__file__, "--flagfile", flagfile_temp])
  tf.gfile.Remove(flagfile_temp)


def write_alive_file(cache_paths):
  """Write file to signal that generation process started correctly."""
  wait_for_path(cache_paths.cache_root)

  log_msg("Signaling that I am alive.")
  with tf.gfile.Open(cache_paths.subproc_alive, "w") as f:
    f.write("Generation subproc has started.")

  @atexit.register
  def remove_alive_file():
    try:
      tf.gfile.Remove(cache_paths.subproc_alive)
    except tf.errors.NotFoundError:
      return  # Main thread has already deleted the entire cache dir.


def main(_):
  # Note: The async process must execute the following two steps in the
  #       following order BEFORE doing anything else:
  #       1) Write the alive file
  #       2) Wait for the flagfile to be written.
  global _log_file
  cache_paths = rconst.Paths(
      data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
  write_alive_file(cache_paths=cache_paths)

  flagfile = os.path.join(cache_paths.cache_root, rconst.FLAGFILE)
  _parse_flagfile(flagfile)

  redirect_logs = flags.FLAGS.redirect_logs

  log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
  log_path = os.path.join(cache_paths.data_dir, log_file_name)
  if log_path.startswith("gs://") and redirect_logs:
    fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
    print("Unable to log to {}. Falling back to {}"
          .format(log_path, fallback_log_file))
    log_path = fallback_log_file

  # This server is generally run in a subprocess.
  if redirect_logs:
    print("Redirecting output of data_async_generation.py process to {}"
          .format(log_path))
    _log_file = open(log_path, "wt")  # Note: not tf.gfile.Open().
  try:
    log_msg("sys.argv: {}".format(" ".join(sys.argv)))

    if flags.FLAGS.seed is not None:
      np.random.seed(flags.FLAGS.seed)

    with mlperf_helper.LOGGER(
        enable=flags.FLAGS.output_ml_perf_compliance_logging):
      mlperf_helper.set_ncf_root(os.path.split(os.path.abspath(__file__))[0])
      _generation_loop(
          num_workers=flags.FLAGS.num_workers,
          cache_paths=cache_paths,
          num_readers=flags.FLAGS.num_readers,
          num_neg=flags.FLAGS.num_neg,
          num_train_positives=flags.FLAGS.num_train_positives,
          num_items=flags.FLAGS.num_items,
          num_users=flags.FLAGS.num_users,
          epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
          num_cycles=flags.FLAGS.num_cycles,
          train_batch_size=flags.FLAGS.train_batch_size,
          eval_batch_size=flags.FLAGS.eval_batch_size,
          deterministic=flags.FLAGS.seed is not None,
          match_mlperf=flags.FLAGS.ml_perf,
      )
  except KeyboardInterrupt:
    log_msg("KeyboardInterrupt registered.")
  except:
    traceback.print_exc(file=_log_file)
    raise
  finally:
    log_msg("Shutting down generation subprocess.")
    sys.stdout.flush()
    sys.stderr.flush()
    if redirect_logs:
      _log_file.close()


def define_flags():
  """Construct flags for the server."""
  flags.DEFINE_integer(name="num_workers", default=multiprocessing.cpu_count(),
                       help="Size of the negative generation worker pool.")
  flags.DEFINE_string(name="data_dir", default=None,
                      help="The data root. (used to construct cache paths.)")
  flags.DEFINE_string(name="cache_id", default=None,
                      help="The cache_id generated in the main process.")
  flags.DEFINE_integer(name="num_readers", default=4,
                       help="Number of reader datasets in training. This sets "
                            "how the epoch files are sharded.")
  flags.DEFINE_integer(name="num_neg", default=None,
                       help="The number of negative instances to pair with a "
                            "positive instance.")
  flags.DEFINE_integer(name="num_train_positives", default=None,
                       help="The number of positive training examples.")
  flags.DEFINE_integer(name="num_items", default=None,
                       help="Number of items from which to select negatives.")
  flags.DEFINE_integer(name="num_users", default=None,
                       help="The number of unique users. Used for evaluation.")
  flags.DEFINE_integer(name="epochs_per_cycle", default=1,
                       help="The number of epochs of training data to produce "
                            "at a time.")
  flags.DEFINE_integer(name="num_cycles", default=None,
                       help="The number of cycles to produce training data "
                            "for.")
  flags.DEFINE_integer(name="train_batch_size", default=None,
                       help="The batch size with which training TFRecords will "
                            "be chunked.")
  flags.DEFINE_integer(name="eval_batch_size", default=None,
                       help="The batch size with which evaluation TFRecords "
                            "will be chunked.")
  flags.DEFINE_boolean(name="redirect_logs", default=False,
                       help="Catch logs and write them to a file. "
                            "(Useful if this is run as a subprocess)")
  flags.DEFINE_boolean(name="use_tf_logging", default=False,
                       help="Use tf.logging instead of log file.")
  flags.DEFINE_integer(name="seed", default=None,
                       help="NumPy random seed to set at startup. If not "
                            "specified, a seed will not be set.")
  flags.DEFINE_boolean(name="ml_perf", default=None,
                       help="Match MLPerf. See ncf_main.py for details.")
  flags.DEFINE_bool(name="output_ml_perf_compliance_logging", default=None,
                    help="Output the MLPerf compliance logging. See "
                         "ncf_main.py for details.")

  flags.mark_flags_as_required(["data_dir", "cache_id"])

if __name__ == "__main__":
  define_flags()
  absl_app.run(main)