Unverified Commit a6e1b895 authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Merge pull request #2732 from tensorflow/wide-deep-shuffle-size

Shuffle an exact epoch of 32561 examples in wide_deep.py rather than 100k
parents 5f0776a2 6d447202
...@@ -61,7 +61,10 @@ parser.add_argument( ...@@ -61,7 +61,10 @@ parser.add_argument(
'--test_data', type=str, default='/tmp/census_data/adult.test', '--test_data', type=str, default='/tmp/census_data/adult.test',
help='Path to the test data.') help='Path to the test data.')
_SHUFFLE_BUFFER = 100000 _NUM_EXAMPLES = {
'train': 32561,
'validation': 16281,
}
def build_model_columns(): def build_model_columns():
...@@ -181,7 +184,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size): ...@@ -181,7 +184,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
dataset = tf.data.TextLineDataset(data_file) dataset = tf.data.TextLineDataset(data_file)
if shuffle: if shuffle:
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER) dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])
dataset = dataset.map(parse_csv, num_parallel_calls=5) dataset = dataset.map(parse_csv, num_parallel_calls=5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment