Commit 0bfdad07 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Makes the default batch_size=99000.

PiperOrigin-RevId: 306453144
parent 0672bdcf
...@@ -331,7 +331,7 @@ class DatasetManager(object): ...@@ -331,7 +331,7 @@ class DatasetManager(object):
"""Returns batches for training.""" """Returns batches for training."""
# Estimator passes batch_size during training and eval_batch_size during # Estimator passes batch_size during training and eval_batch_size during
# eval. TPUEstimator only passes batch_size. # eval.
param_batch_size = (params["batch_size"] if self._is_training else param_batch_size = (params["batch_size"] if self._is_training else
params.get("eval_batch_size") or params["batch_size"]) params.get("eval_batch_size") or params["batch_size"])
if batch_size != param_batch_size: if batch_size != param_batch_size:
...@@ -713,7 +713,7 @@ class DummyConstructor(threading.Thread): ...@@ -713,7 +713,7 @@ class DummyConstructor(threading.Thread):
"""Returns dummy input batches for training.""" """Returns dummy input batches for training."""
# Estimator passes batch_size during training and eval_batch_size during # Estimator passes batch_size during training and eval_batch_size during
# eval. TPUEstimator only passes batch_size. # eval.
batch_size = (params["batch_size"] if is_training else batch_size = (params["batch_size"] if is_training else
params.get("eval_batch_size") or params["batch_size"]) params.get("eval_batch_size") or params["batch_size"])
num_users = params["num_users"] num_users = params["num_users"]
......
...@@ -167,7 +167,7 @@ def define_ncf_flags(): ...@@ -167,7 +167,7 @@ def define_ncf_flags():
model_dir="/tmp/ncf/", model_dir="/tmp/ncf/",
data_dir="/tmp/movielens-data/", data_dir="/tmp/movielens-data/",
train_epochs=2, train_epochs=2,
batch_size=256, batch_size=99000,
hooks="ProfilerHook", hooks="ProfilerHook",
tpu=None tpu=None
) )
......
#!/bin/bash
set -e
# Example settings:
# export TPU="taylorrobie-tpu-0"
# export BUCKET="gs://taylorrobie-tpu-test-bucket-2"
# Remove IDE "not assigned" warning highlights.
TPU=${TPU:-""}
BUCKET=${BUCKET:-""}
if [[ -z ${TPU} ]]; then
echo "Please set 'TPU' to the name of the TPU to be used."
exit 1
fi
if [[ -z ${BUCKET} ]]; then
echo "Please set 'BUCKET' to the GCS bucket to be used."
exit 1
fi
./run.sh
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment