Unverified Commit 20070ca4 authored by Taylor Robie, committed by GitHub

Wide Deep refactor and deep movies (#4506)

* begin branch

* finish download script

* rename download to dataset

* intermediate commit

* intermediate commit

* misc tweaks

* intermediate commit

* intermediate commit

* intermediate commit

* delint and update census test.

* add movie tests

* delint

* fix py2 issue

* address PR comments

* intermediate commit

* intermediate commit

* intermediate commit

* finish wide deep transition to vanilla movielens

* delint

* intermediate commit

* intermediate commit

* intermediate commit

* intermediate commit

* fix import

* add default ncf csv construction

* change default on download_if_missing

* shard and vectorize example serialization

* fix import

* update ncf data unittests

* delint

* delint

* more delinting

* fix wide-deep movielens serialization

* address PR comments

* add file_io tests

* investigate wide-deep test failure

* remove hard coded path and properly use flags.

* address file_io test PR comments

* missed a hash_bucket_size
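
Rough usage sketch for the refactored entry points (not part of the diff; the paths and flag values are simply the defaults defined in this PR's flag code, and exact invocations may differ):

    # Optionally fetch and clean the census data ahead of time.
    python official/wide_deep/census_dataset.py --data_dir /tmp/census_data

    # Train the census wide & deep model (data is downloaded if missing).
    python official/wide_deep/census_main.py --model_type wide_deep \
        --data_dir /tmp/census_data --model_dir /tmp/census_model

    # Train the deep MovieLens rating regressor.
    python official/wide_deep/movielens_main.py --dataset ml-1m \
        --data_dir /tmp/movielens-data --model_dir /tmp/movie_model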
parent 713228fd

official/utils/misc/model_helpers.py:

@@ -84,3 +84,10 @@ def generate_synthetic_data(
   element = (input_element, label_element)
   return tf.data.Dataset.from_tensors(element).repeat()
+
+
+def apply_clean(flags_obj):
+  if flags_obj.clean and tf.gfile.Exists(flags_obj.model_dir):
+    tf.logging.info("--clean flag set. Removing existing model dir: {}".format(
+        flags_obj.model_dir))
+    tf.gfile.DeleteRecursively(flags_obj.model_dir)
official/wide_deep/census_dataset.py:

@@ -12,22 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Example code for TensorFlow Wide & Deep Tutorial using tf.estimator API."""
+"""Download and clean the Census Income Dataset."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import os
-import shutil
+import sys

+# pylint: disable=wrong-import-order
 from absl import app as absl_app
 from absl import flags
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+from six.moves import urllib
+import tensorflow as tf
+# pylint: enable=wrong-import-order

 from official.utils.flags import core as flags_core
-from official.utils.logs import hooks_helper
-from official.utils.logs import logger
-from official.utils.misc import model_helpers
+
+DATA_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult'
+TRAINING_FILE = 'adult.data'
+TRAINING_URL = '%s/%s' % (DATA_URL, TRAINING_FILE)
+EVAL_FILE = 'adult.test'
+EVAL_URL = '%s/%s' % (DATA_URL, EVAL_FILE)

 _CSV_COLUMNS = [
@@ -40,37 +48,47 @@ _CSV_COLUMNS = [
 _CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
                         [0], [0], [0], [''], ['']]

+_HASH_BUCKET_SIZE = 1000
+
 _NUM_EXAMPLES = {
     'train': 32561,
     'validation': 16281,
 }

-LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
-
-
-def define_wide_deep_flags():
-  """Add supervised learning flags, as well as wide-deep model type."""
-  flags_core.define_base()
-  flags_core.define_benchmark()
-  flags.adopt_module_key_flags(flags_core)
-
-  flags.DEFINE_enum(
-      name="model_type", short_name="mt", default="wide_deep",
-      enum_values=['wide', 'deep', 'wide_deep'],
-      help="Select model topology.")
-
-  flags_core.set_defaults(data_dir='/tmp/census_data',
-                          model_dir='/tmp/census_model',
-                          train_epochs=40,
-                          epochs_between_evals=2,
-                          batch_size=40)
+
+def _download_and_clean_file(filename, url):
+  """Downloads data from url, and makes changes to match the CSV format."""
+  temp_file, _ = urllib.request.urlretrieve(url)
+  with tf.gfile.Open(temp_file, 'r') as temp_eval_file:
+    with tf.gfile.Open(filename, 'w') as eval_file:
+      for line in temp_eval_file:
+        line = line.strip()
+        line = line.replace(', ', ',')
+        if not line or ',' not in line:
+          continue
+        if line[-1] == '.':
+          line = line[:-1]
+        line += '\n'
+        eval_file.write(line)
+  tf.gfile.Remove(temp_file)
+
+
+def download(data_dir):
+  """Download census data if it is not already present."""
+  tf.gfile.MakeDirs(data_dir)
+
+  training_file_path = os.path.join(data_dir, TRAINING_FILE)
+  if not tf.gfile.Exists(training_file_path):
+    _download_and_clean_file(training_file_path, TRAINING_URL)
+
+  eval_file_path = os.path.join(data_dir, EVAL_FILE)
+  if not tf.gfile.Exists(eval_file_path):
+    _download_and_clean_file(eval_file_path, EVAL_URL)


 def build_model_columns():
   """Builds a set of wide and deep feature columns."""
-  # Continuous columns
+  # Continuous variable columns
   age = tf.feature_column.numeric_column('age')
   education_num = tf.feature_column.numeric_column('education_num')
   capital_gain = tf.feature_column.numeric_column('capital_gain')
@@ -100,7 +118,7 @@ def build_model_columns():
   # To show an example of hashing:
   occupation = tf.feature_column.categorical_column_with_hash_bucket(
-      'occupation', hash_bucket_size=1000)
+      'occupation', hash_bucket_size=_HASH_BUCKET_SIZE)

   # Transformations.
   age_buckets = tf.feature_column.bucketized_column(
@@ -114,9 +132,10 @@ def build_model_columns():
   crossed_columns = [
       tf.feature_column.crossed_column(
-          ['education', 'occupation'], hash_bucket_size=1000),
+          ['education', 'occupation'], hash_bucket_size=_HASH_BUCKET_SIZE),
       tf.feature_column.crossed_column(
-          [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
+          [age_buckets, 'education', 'occupation'],
+          hash_bucket_size=_HASH_BUCKET_SIZE),
   ]

   wide_columns = base_columns + crossed_columns
@@ -138,48 +157,19 @@ def build_model_columns():

   return wide_columns, deep_columns


-def build_estimator(model_dir, model_type):
-  """Build an estimator appropriate for the given model type."""
-  wide_columns, deep_columns = build_model_columns()
-  hidden_units = [100, 75, 50, 25]
-
-  # Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
-  # trains faster than GPU for this model.
-  run_config = tf.estimator.RunConfig().replace(
-      session_config=tf.ConfigProto(device_count={'GPU': 0}))
-
-  if model_type == 'wide':
-    return tf.estimator.LinearClassifier(
-        model_dir=model_dir,
-        feature_columns=wide_columns,
-        config=run_config)
-  elif model_type == 'deep':
-    return tf.estimator.DNNClassifier(
-        model_dir=model_dir,
-        feature_columns=deep_columns,
-        hidden_units=hidden_units,
-        config=run_config)
-  else:
-    return tf.estimator.DNNLinearCombinedClassifier(
-        model_dir=model_dir,
-        linear_feature_columns=wide_columns,
-        dnn_feature_columns=deep_columns,
-        dnn_hidden_units=hidden_units,
-        config=run_config)
-
-
 def input_fn(data_file, num_epochs, shuffle, batch_size):
   """Generate an input function for the Estimator."""
   assert tf.gfile.Exists(data_file), (
-      '%s not found. Please make sure you have run data_download.py and '
+      '%s not found. Please make sure you have run census_dataset.py and '
       'set the --data_dir argument to the correct path.' % data_file)

   def parse_csv(value):
-    print('Parsing', data_file)
+    tf.logging.info('Parsing {}'.format(data_file))
     columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
     features = dict(zip(_CSV_COLUMNS, columns))
     labels = features.pop('income_bracket')
-    return features, tf.equal(labels, '>50K')
+    classes = tf.equal(labels, '>50K')  # binary classification
+    return features, classes

   # Extract lines from input files using the Dataset API.
   dataset = tf.data.TextLineDataset(data_file)
@@ -196,96 +186,19 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):

   return dataset


-def export_model(model, model_type, export_dir):
-  """Export to SavedModel format.
-
-  Args:
-    model: Estimator object
-    model_type: string indicating model type. "wide", "deep" or "wide_deep"
-    export_dir: directory to export the model.
-  """
-  wide_columns, deep_columns = build_model_columns()
-  if model_type == 'wide':
-    columns = wide_columns
-  elif model_type == 'deep':
-    columns = deep_columns
-  else:
-    columns = wide_columns + deep_columns
-  feature_spec = tf.feature_column.make_parse_example_spec(columns)
-  example_input_fn = (
-      tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
-  model.export_savedmodel(export_dir, example_input_fn)
-
-
-def run_wide_deep(flags_obj):
-  """Run Wide-Deep training and eval loop.
-
-  Args:
-    flags_obj: An object containing parsed flag values.
-  """
-
-  # Clean up the model directory if present
-  shutil.rmtree(flags_obj.model_dir, ignore_errors=True)
-  model = build_estimator(flags_obj.model_dir, flags_obj.model_type)
-
-  train_file = os.path.join(flags_obj.data_dir, 'adult.data')
-  test_file = os.path.join(flags_obj.data_dir, 'adult.test')
-
-  # Train and evaluate the model every `flags.epochs_between_evals` epochs.
-  def train_input_fn():
-    return input_fn(
-        train_file, flags_obj.epochs_between_evals, True, flags_obj.batch_size)
-
-  def eval_input_fn():
-    return input_fn(test_file, 1, False, flags_obj.batch_size)
-
-  run_params = {
-      'batch_size': flags_obj.batch_size,
-      'train_epochs': flags_obj.train_epochs,
-      'model_type': flags_obj.model_type,
-  }
-
-  benchmark_logger = logger.get_benchmark_logger()
-  benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params,
-                                test_id=flags_obj.benchmark_test_id)
-
-  loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
-  train_hooks = hooks_helper.get_train_hooks(
-      flags_obj.hooks, batch_size=flags_obj.batch_size,
-      tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
-                      'loss': loss_prefix + 'head/weighted_loss/Sum'})
-
-  # Train and evaluate the model every `flags.epochs_between_evals` epochs.
-  for n in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
-    model.train(input_fn=train_input_fn, hooks=train_hooks)
-    results = model.evaluate(input_fn=eval_input_fn)
-
-    # Display evaluation metrics
-    tf.logging.info('Results at epoch %d / %d',
-                    (n + 1) * flags_obj.epochs_between_evals,
-                    flags_obj.train_epochs)
-    tf.logging.info('-' * 60)
-
-    for key in sorted(results):
-      tf.logging.info('%s: %s' % (key, results[key]))
-
-    benchmark_logger.log_evaluation_result(results)
-
-    if model_helpers.past_stop_threshold(
-        flags_obj.stop_threshold, results['accuracy']):
-      break
-
-  # Export the model
-  if flags_obj.export_dir is not None:
-    export_model(model, flags_obj.model_type, flags_obj.export_dir)
+def define_data_download_flags():
+  """Add flags specifying data download arguments."""
+  flags.DEFINE_string(
+      name="data_dir", default="/tmp/census_data/",
+      help=flags_core.help_wrap(
+          "Directory to download and extract data."))


 def main(_):
-  with logger.benchmark_context(flags.FLAGS):
-    run_wide_deep(flags.FLAGS)
+  download(flags.FLAGS.data_dir)


 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  define_wide_deep_flags()
+  define_data_download_flags()
   absl_app.run(main)


official/wide_deep/census_main.py (new file):
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train DNN on census income dataset."""
import os
from absl import app as absl_app
from absl import flags
import tensorflow as tf
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.wide_deep import census_dataset
from official.wide_deep import wide_deep_run_loop
def define_census_flags():
wide_deep_run_loop.define_wide_deep_flags()
flags.adopt_module_key_flags(wide_deep_run_loop)
flags_core.set_defaults(data_dir='/tmp/census_data',
model_dir='/tmp/census_model',
train_epochs=40,
epochs_between_evals=2,
batch_size=40)
def build_estimator(model_dir, model_type, model_column_fn):
"""Build an estimator appropriate for the given model type."""
wide_columns, deep_columns = model_column_fn()
hidden_units = [100, 75, 50, 25]
# Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
# trains faster than GPU for this model.
run_config = tf.estimator.RunConfig().replace(
session_config=tf.ConfigProto(device_count={'GPU': 0}))
if model_type == 'wide':
return tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=wide_columns,
config=run_config)
elif model_type == 'deep':
return tf.estimator.DNNClassifier(
model_dir=model_dir,
feature_columns=deep_columns,
hidden_units=hidden_units,
config=run_config)
else:
return tf.estimator.DNNLinearCombinedClassifier(
model_dir=model_dir,
linear_feature_columns=wide_columns,
dnn_feature_columns=deep_columns,
dnn_hidden_units=hidden_units,
config=run_config)
def run_census(flags_obj):
"""Construct all necessary functions and call run_loop.
Args:
flags_obj: Object containing user specified flags.
"""
if flags_obj.download_if_missing:
census_dataset.download(flags_obj.data_dir)
train_file = os.path.join(flags_obj.data_dir, census_dataset.TRAINING_FILE)
test_file = os.path.join(flags_obj.data_dir, census_dataset.EVAL_FILE)
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
def train_input_fn():
return census_dataset.input_fn(
train_file, flags_obj.epochs_between_evals, True, flags_obj.batch_size)
def eval_input_fn():
return census_dataset.input_fn(test_file, 1, False, flags_obj.batch_size)
tensors_to_log = {
'average_loss': '{loss_prefix}head/truediv',
'loss': '{loss_prefix}head/weighted_loss/Sum'
}
wide_deep_run_loop.run_loop(
name="Census Income", train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
model_column_fn=census_dataset.build_model_columns,
build_estimator_fn=build_estimator,
flags_obj=flags_obj,
tensors_to_log=tensors_to_log,
early_stop=True)
def main(_):
with logger.benchmark_context(flags.FLAGS):
run_census(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
define_census_flags()
absl_app.run(main)
official/wide_deep/census_test.py:

@@ -22,7 +22,9 @@ import os
 import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.utils.testing import integration
-from official.wide_deep import wide_deep
+from official.wide_deep import census_dataset
+from official.wide_deep import census_main
+from official.wide_deep import wide_deep_run_loop

 tf.logging.set_verbosity(tf.logging.ERROR)

@@ -42,7 +44,7 @@ TEST_INPUT_VALUES = {
     'occupation': 'abc',
 }

-TEST_CSV = os.path.join(os.path.dirname(__file__), 'wide_deep_test.csv')
+TEST_CSV = os.path.join(os.path.dirname(__file__), 'census_test.csv')


 class BaseTest(tf.test.TestCase):

@@ -51,7 +53,7 @@ class BaseTest(tf.test.TestCase):
   @classmethod
   def setUpClass(cls):  # pylint: disable=invalid-name
     super(BaseTest, cls).setUpClass()
-    wide_deep.define_wide_deep_flags()
+    census_main.define_census_flags()

   def setUp(self):
     # Create temporary CSV file

@@ -64,15 +66,15 @@ class BaseTest(tf.test.TestCase):
       test_csv_contents = temp_csv.read()

     # Used for end-to-end tests.
-    for fname in ['adult.data', 'adult.test']:
+    for fname in [census_dataset.TRAINING_FILE, census_dataset.EVAL_FILE]:
       with tf.gfile.Open(os.path.join(self.temp_dir, fname), 'w') as test_csv:
         test_csv.write(test_csv_contents)

   def test_input_fn(self):
-    dataset = wide_deep.input_fn(self.input_csv, 1, False, 1)
+    dataset = census_dataset.input_fn(self.input_csv, 1, False, 1)
     features, labels = dataset.make_one_shot_iterator().get_next()

-    with tf.Session() as sess:
+    with self.test_session() as sess:
       features, labels = sess.run((features, labels))

       # Compare the two features dictionaries.

@@ -91,12 +93,14 @@
   def build_and_test_estimator(self, model_type):
     """Ensure that model trains and minimizes loss."""
-    model = wide_deep.build_estimator(self.temp_dir, model_type)
+    model = census_main.build_estimator(
+        self.temp_dir, model_type,
+        model_column_fn=census_dataset.build_model_columns)

     # Train for 1 step to initialize model and evaluate initial loss
     def get_input_fn(num_epochs, shuffle, batch_size):
       def input_fn():
-        return wide_deep.input_fn(
+        return census_dataset.input_fn(
             TEST_CSV, num_epochs=num_epochs, shuffle=shuffle,
             batch_size=batch_size)
       return input_fn

@@ -123,25 +127,31 @@
   def test_end_to_end_wide(self):
     integration.run_synthetic(
-        main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
+        main=census_main.main, tmp_root=self.get_temp_dir(),
+        extra_flags=[
             '--data_dir', self.get_temp_dir(),
             '--model_type', 'wide',
+            '--download_if_missing=false'
         ],
         synth=False, max_train=None)

   def test_end_to_end_deep(self):
     integration.run_synthetic(
-        main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
+        main=census_main.main, tmp_root=self.get_temp_dir(),
+        extra_flags=[
            '--data_dir', self.get_temp_dir(),
            '--model_type', 'deep',
+           '--download_if_missing=false'
        ],
        synth=False, max_train=None)

   def test_end_to_end_wide_deep(self):
     integration.run_synthetic(
-        main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
+        main=census_main.main, tmp_root=self.get_temp_dir(),
+        extra_flags=[
            '--data_dir', self.get_temp_dir(),
            '--model_type', 'wide_deep',
+           '--download_if_missing=false'
        ],
        synth=False, max_train=None)
...

official/wide_deep/data_download.py (deleted):
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and clean the Census Income Dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
from six.moves import urllib
import tensorflow as tf
DATA_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult'
TRAINING_FILE = 'adult.data'
TRAINING_URL = '%s/%s' % (DATA_URL, TRAINING_FILE)
EVAL_FILE = 'adult.test'
EVAL_URL = '%s/%s' % (DATA_URL, EVAL_FILE)
parser = argparse.ArgumentParser()
parser.add_argument(
'--data_dir', type=str, default='/tmp/census_data',
help='Directory to download census data')
def _download_and_clean_file(filename, url):
"""Downloads data from url, and makes changes to match the CSV format."""
temp_file, _ = urllib.request.urlretrieve(url)
with tf.gfile.Open(temp_file, 'r') as temp_eval_file:
with tf.gfile.Open(filename, 'w') as eval_file:
for line in temp_eval_file:
line = line.strip()
line = line.replace(', ', ',')
if not line or ',' not in line:
continue
if line[-1] == '.':
line = line[:-1]
line += '\n'
eval_file.write(line)
tf.gfile.Remove(temp_file)
def main(_):
if not tf.gfile.Exists(FLAGS.data_dir):
tf.gfile.MkDir(FLAGS.data_dir)
training_file_path = os.path.join(FLAGS.data_dir, TRAINING_FILE)
_download_and_clean_file(training_file_path, TRAINING_URL)
eval_file_path = os.path.join(FLAGS.data_dir, EVAL_FILE)
_download_and_clean_file(eval_file_path, EVAL_URL)
if __name__ == '__main__':
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(argv=[sys.argv[0]] + unparsed)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Prepare MovieLens dataset for wide-deep."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
# pylint: disable=wrong-import-order
from absl import app as absl_app
from absl import flags
import numpy as np
import tensorflow as tf
# pylint: enable=wrong-import-order
from official.datasets import movielens
from official.utils.data import file_io
from official.utils.flags import core as flags_core
_BUFFER_SUBDIR = "wide_deep_buffer"
_FEATURE_MAP = {
movielens.USER_COLUMN: tf.FixedLenFeature([1], dtype=tf.int64),
movielens.ITEM_COLUMN: tf.FixedLenFeature([1], dtype=tf.int64),
movielens.TIMESTAMP_COLUMN: tf.FixedLenFeature([1], dtype=tf.int64),
movielens.GENRE_COLUMN: tf.FixedLenFeature(
[movielens.N_GENRE], dtype=tf.int64),
movielens.RATING_COLUMN: tf.FixedLenFeature([1], dtype=tf.float32),
}
_BUFFER_SIZE = {
movielens.ML_1M: {"train": 107978119, "eval": 26994538},
movielens.ML_20M: {"train": 2175203810, "eval": 543802008}
}
_USER_EMBEDDING_DIM = 16
_ITEM_EMBEDDING_DIM = 64
def build_model_columns(dataset):
"""Builds a set of wide and deep feature columns."""
user_id = tf.feature_column.categorical_column_with_vocabulary_list(
movielens.USER_COLUMN, range(1, movielens.NUM_USER_IDS[dataset]))
user_embedding = tf.feature_column.embedding_column(
user_id, _USER_EMBEDDING_DIM, max_norm=np.sqrt(_USER_EMBEDDING_DIM))
item_id = tf.feature_column.categorical_column_with_vocabulary_list(
movielens.ITEM_COLUMN, range(1, movielens.NUM_ITEM_IDS))
item_embedding = tf.feature_column.embedding_column(
item_id, _ITEM_EMBEDDING_DIM, max_norm=np.sqrt(_ITEM_EMBEDDING_DIM))
time = tf.feature_column.numeric_column(movielens.TIMESTAMP_COLUMN)
genres = tf.feature_column.numeric_column(
movielens.GENRE_COLUMN, shape=(movielens.N_GENRE,), dtype=tf.uint8)
deep_columns = [user_embedding, item_embedding, time, genres]
wide_columns = []
return wide_columns, deep_columns
def _deserialize(examples_serialized):
features = tf.parse_example(examples_serialized, _FEATURE_MAP)
return features, features[movielens.RATING_COLUMN] / movielens.MAX_RATING
def _buffer_path(data_dir, dataset, name):
return os.path.join(data_dir, _BUFFER_SUBDIR,
"{}_{}_buffer".format(dataset, name))
def _df_to_input_fn(df, name, dataset, data_dir, batch_size, repeat, shuffle):
"""Serialize a dataframe and write it to a buffer file."""
buffer_path = _buffer_path(data_dir, dataset, name)
expected_size = _BUFFER_SIZE[dataset].get(name)
file_io.write_to_buffer(
dataframe=df, buffer_path=buffer_path,
columns=list(_FEATURE_MAP.keys()), expected_size=expected_size)
def input_fn():
dataset = tf.data.TFRecordDataset(buffer_path)
# batch comes before map because map can deserialize multiple examples.
dataset = dataset.batch(batch_size)
dataset = dataset.map(_deserialize, num_parallel_calls=16)
if shuffle:
dataset = dataset.shuffle(shuffle)
dataset = dataset.repeat(repeat)
return dataset.prefetch(1)
return input_fn
def _check_buffers(data_dir, dataset):
train_path = os.path.join(data_dir, _BUFFER_SUBDIR,
"{}_{}_buffer".format(dataset, "train"))
eval_path = os.path.join(data_dir, _BUFFER_SUBDIR,
"{}_{}_buffer".format(dataset, "eval"))
if not tf.gfile.Exists(train_path) or not tf.gfile.Exists(eval_path):
return False
return all([
tf.gfile.Stat(_buffer_path(data_dir, dataset, "train")).length ==
_BUFFER_SIZE[dataset]["train"],
tf.gfile.Stat(_buffer_path(data_dir, dataset, "eval")).length ==
_BUFFER_SIZE[dataset]["eval"],
])
def construct_input_fns(dataset, data_dir, batch_size=16, repeat=1):
"""Construct train and test input functions, as well as the column fn."""
if _check_buffers(data_dir, dataset):
train_df, eval_df = None, None
else:
df = movielens.csv_to_joint_dataframe(dataset=dataset, data_dir=data_dir)
df = movielens.integerize_genres(dataframe=df)
df = df.drop(columns=[movielens.TITLE_COLUMN])
train_df = df.sample(frac=0.8, random_state=0)
eval_df = df.drop(train_df.index)
train_df = train_df.reset_index(drop=True)
eval_df = eval_df.reset_index(drop=True)
train_input_fn = _df_to_input_fn(
df=train_df, name="train", dataset=dataset, data_dir=data_dir,
batch_size=batch_size, repeat=repeat,
shuffle=movielens.NUM_RATINGS[dataset])
eval_input_fn = _df_to_input_fn(
df=eval_df, name="eval", dataset=dataset, data_dir=data_dir,
batch_size=batch_size, repeat=repeat, shuffle=None)
model_column_fn = functools.partial(build_model_columns, dataset=dataset)
train_input_fn()
return train_input_fn, eval_input_fn, model_column_fn
def main(_):
movielens.download(dataset=flags.FLAGS.dataset, data_dir=flags.FLAGS.data_dir)
construct_input_fns(flags.FLAGS.dataset, flags.FLAGS.data_dir)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
movielens.define_data_download_flags()
flags.adopt_module_key_flags(movielens)
flags_core.set_defaults(dataset="ml-1m")
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train DNN on Kaggle movie dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app as absl_app
from absl import flags
import tensorflow as tf
from official.datasets import movielens
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.wide_deep import movielens_dataset
from official.wide_deep import wide_deep_run_loop
def define_movie_flags():
"""Define flags for movie dataset training."""
wide_deep_run_loop.define_wide_deep_flags()
flags.DEFINE_enum(
name="dataset", default=movielens.ML_1M,
enum_values=movielens.DATASETS, case_sensitive=False,
help=flags_core.help_wrap("Dataset to be trained and evaluated."))
flags.adopt_module_key_flags(wide_deep_run_loop)
flags_core.set_defaults(data_dir="/tmp/movielens-data/",
model_dir='/tmp/movie_model',
model_type="deep",
train_epochs=50,
epochs_between_evals=5,
batch_size=256)
@flags.validator("stop_threshold",
message="stop_threshold not supported for movielens model")
def _no_stop(stop_threshold):
return stop_threshold is None
def build_estimator(model_dir, model_type, model_column_fn):
"""Build an estimator appropriate for the given model type."""
if model_type != "deep":
raise NotImplementedError("movie dataset only supports `deep` model_type")
_, deep_columns = model_column_fn()
hidden_units = [256, 256, 256, 128]
return tf.estimator.DNNRegressor(
model_dir=model_dir,
feature_columns=deep_columns,
hidden_units=hidden_units,
optimizer=tf.train.AdamOptimizer(),
activation_fn=tf.nn.sigmoid,
dropout=0.3,
loss_reduction=tf.losses.Reduction.MEAN)
def run_movie(flags_obj):
"""Construct all necessary functions and call run_loop.
Args:
flags_obj: Object containing user specified flags.
"""
if flags_obj.download_if_missing:
movielens.download(dataset=flags_obj.dataset, data_dir=flags_obj.data_dir)
train_input_fn, eval_input_fn, model_column_fn = \
movielens_dataset.construct_input_fns(
dataset=flags_obj.dataset, data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size, repeat=flags_obj.epochs_between_evals)
tensors_to_log = {
'loss': '{loss_prefix}head/weighted_loss/value'
}
wide_deep_run_loop.run_loop(
name="MovieLens", train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
model_column_fn=model_column_fn,
build_estimator_fn=build_estimator,
flags_obj=flags_obj,
tensors_to_log=tensors_to_log,
early_stop=False)
def main(_):
with logger.benchmark_context(flags.FLAGS):
run_movie(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
define_movie_flags()
absl_app.run(main)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.datasets import movielens
from official.utils.testing import integration
from official.wide_deep import movielens_dataset
from official.wide_deep import movielens_main
from official.wide_deep import wide_deep_run_loop

tf.logging.set_verbosity(tf.logging.ERROR)


TEST_INPUT_VALUES = {
    "genres": np.array(
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    "user_id": [3],
    "item_id": [4],
}

TEST_ITEM_DATA = """item_id,titles,genres
1,Movie_1,Comedy|Romance
2,Movie_2,Adventure|Children's
3,Movie_3,Comedy|Drama
4,Movie_4,Comedy
5,Movie_5,Action|Crime|Thriller
6,Movie_6,Action
7,Movie_7,Action|Adventure|Thriller"""

TEST_RATING_DATA = """user_id,item_id,rating,timestamp
1,2,5,978300760
1,3,3,978302109
1,6,3,978301968
2,1,4,978300275
2,7,5,978824291
3,1,3,978302268
3,4,5,978302039
3,5,5,978300719
"""


class BaseTest(tf.test.TestCase):
  """Tests for Wide Deep model."""

  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    super(BaseTest, cls).setUpClass()
    movielens_main.define_movie_flags()

  def setUp(self):
    # Create temporary CSV file
    self.temp_dir = self.get_temp_dir()
    tf.gfile.MakeDirs(os.path.join(self.temp_dir, movielens.ML_1M))

    self.ratings_csv = os.path.join(
        self.temp_dir, movielens.ML_1M, movielens.RATINGS_FILE)
    self.item_csv = os.path.join(
        self.temp_dir, movielens.ML_1M, movielens.MOVIES_FILE)

    with tf.gfile.Open(self.ratings_csv, "w") as f:
      f.write(TEST_RATING_DATA)

    with tf.gfile.Open(self.item_csv, "w") as f:
      f.write(TEST_ITEM_DATA)

  def test_input_fn(self):
    train_input_fn, _, _ = movielens_dataset.construct_input_fns(
        dataset=movielens.ML_1M, data_dir=self.temp_dir, batch_size=8, repeat=1)

    dataset = train_input_fn()
    features, labels = dataset.make_one_shot_iterator().get_next()

    with self.test_session() as sess:
      features, labels = sess.run((features, labels))

      # Compare the two features dictionaries.
      for key in TEST_INPUT_VALUES:
        self.assertTrue(key in features)
        self.assertAllClose(TEST_INPUT_VALUES[key], features[key][0])

      self.assertAllClose(labels[0], [1.0])

  def test_end_to_end_deep(self):
    integration.run_synthetic(
        main=movielens_main.main, tmp_root=self.temp_dir,
        extra_flags=[
            "--data_dir", self.temp_dir,
            "--download_if_missing=false",
            "--train_epochs", "1",
            "--epochs_between_evals", "1"
        ],
        synth=False, max_train=None)


if __name__ == "__main__":
  tf.test.main()


official/wide_deep/wide_deep_run_loop.py (new file):
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core run logic for TensorFlow Wide & Deep Tutorial using tf.estimator API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import model_helpers
LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base()
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_enum(
name="model_type", short_name="mt", default="wide_deep",
enum_values=['wide', 'deep', 'wide_deep'],
help="Select model topology.")
flags.DEFINE_boolean(
name="download_if_missing", default=True, help=flags_core.help_wrap(
"Download data to data_dir if it is not already present."))
def export_model(model, model_type, export_dir, model_column_fn):
"""Export to SavedModel format.
Args:
model: Estimator object
model_type: string indicating model type. "wide", "deep" or "wide_deep"
export_dir: directory to export the model.
model_column_fn: Function to generate model feature columns.
"""
wide_columns, deep_columns = model_column_fn()
if model_type == 'wide':
columns = wide_columns
elif model_type == 'deep':
columns = deep_columns
else:
columns = wide_columns + deep_columns
feature_spec = tf.feature_column.make_parse_example_spec(columns)
example_input_fn = (
tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
model.export_savedmodel(export_dir, example_input_fn)
def run_loop(name, train_input_fn, eval_input_fn, model_column_fn,
build_estimator_fn, flags_obj, tensors_to_log, early_stop=False):
"""Define training loop."""
model_helpers.apply_clean(flags.FLAGS)
model = build_estimator_fn(
model_dir=flags_obj.model_dir, model_type=flags_obj.model_type,
model_column_fn=model_column_fn)
run_params = {
'batch_size': flags_obj.batch_size,
'train_epochs': flags_obj.train_epochs,
'model_type': flags_obj.model_type,
}
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info('wide_deep', name, run_params,
test_id=flags_obj.benchmark_test_id)
loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
tensors_to_log = {k: v.format(loss_prefix=loss_prefix)
for k, v in tensors_to_log.items()}
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks, model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size, tensors_to_log=tensors_to_log)
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
for n in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
model.train(input_fn=train_input_fn, hooks=train_hooks)
results = model.evaluate(input_fn=eval_input_fn)
# Display evaluation metrics
tf.logging.info('Results at epoch %d / %d',
(n + 1) * flags_obj.epochs_between_evals,
flags_obj.train_epochs)
tf.logging.info('-' * 60)
for key in sorted(results):
tf.logging.info('%s: %s' % (key, results[key]))
benchmark_logger.log_evaluation_result(results)
if early_stop and model_helpers.past_stop_threshold(
flags_obj.stop_threshold, results['accuracy']):
break
# Export the model
if flags_obj.export_dir is not None:
export_model(model, flags_obj.model_type, flags_obj.export_dir,
model_column_fn)