Unverified Commit aec1fec6 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Fix/ncf eval default (#5438)

* improve default handling for eval_batch_size

* return eval_batch_size default to None

* fix syntax error
parent 505cad95
...@@ -128,8 +128,9 @@ def run_ncf(_): ...@@ -128,8 +128,9 @@ def run_ncf(_):
batch_size = distribution_utils.per_device_batch_size( batch_size = distribution_utils.per_device_batch_size(
int(FLAGS.batch_size), num_gpus) int(FLAGS.batch_size), num_gpus)
eval_batch_size = int(FLAGS.eval_batch_size or FLAGS.batch_size)
eval_per_user = rconst.NUM_EVAL_NEGATIVES + 1 eval_per_user = rconst.NUM_EVAL_NEGATIVES + 1
eval_batch_size = int(FLAGS.eval_batch_size or
max([FLAGS.batch_size, eval_per_user]))
if eval_batch_size % eval_per_user: if eval_batch_size % eval_per_user:
eval_batch_size = eval_batch_size // eval_per_user * eval_per_user eval_batch_size = eval_batch_size // eval_per_user * eval_per_user
tf.logging.warning( tf.logging.warning(
...@@ -365,7 +366,8 @@ def define_ncf_flags(): ...@@ -365,7 +366,8 @@ def define_ncf_flags():
@flags.validator("eval_batch_size", "eval_batch_size must be at least {}" @flags.validator("eval_batch_size", "eval_batch_size must be at least {}"
.format(rconst.NUM_EVAL_NEGATIVES + 1)) .format(rconst.NUM_EVAL_NEGATIVES + 1))
def eval_size_check(eval_batch_size): def eval_size_check(eval_batch_size):
return int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES return (eval_batch_size is None or
int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment