wide_deep.py 8.07 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
田传武's avatar
田传武 committed
15
"""Example code for TensorFlow Wide & Deep Tutorial using tf.estimator API."""
16
17
18
19
20
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
21
import os
22
23
24
import shutil
import sys

Karmel Allison's avatar
Karmel Allison committed
25
import tensorflow as tf  # pylint: disable=g-bad-import-order
26

Karmel Allison's avatar
Karmel Allison committed
27
from official.utils.arg_parsers import parsers
28
29
from official.utils.logging import hooks_helper

30
31
32
33
34
35
36
37
38
39
# Column names of the census CSV files, listed in the order the fields appear
# on each line. The last column, 'income_bracket', is the label.
_CSV_COLUMNS = [
    'age',
    'workclass',
    'fnlwgt',
    'education',
    'education_num',
    'marital_status',
    'occupation',
    'relationship',
    'race',
    'gender',
    'capital_gain',
    'capital_loss',
    'hours_per_week',
    'native_country',
    'income_bracket',
]

# Per-column defaults handed to `tf.decode_csv` as `record_defaults`. Entries
# line up positionally with _CSV_COLUMNS.
_CSV_COLUMN_DEFAULTS = [
    [0],   # age
    [''],  # workclass
    [0],   # fnlwgt
    [''],  # education
    [0],   # education_num
    [''],  # marital_status
    [''],  # occupation
    [''],  # relationship
    [''],  # race
    [''],  # gender
    [0],   # capital_gain
    [0],   # capital_loss
    [0],   # hours_per_week
    [''],  # native_country
    [''],  # income_bracket
]

# Example counts per data split. The 'train' count doubles as the shuffle
# buffer size in `input_fn`.
_NUM_EXAMPLES = dict(train=32561, validation=16281)
Neal Wu's avatar
Neal Wu committed
44

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148

def build_model_columns():
  """Create the feature columns for the wide and the deep model parts.

  Returns:
    A `(wide_columns, deep_columns)` tuple of lists. `wide_columns` holds the
    categorical and bucketized base columns followed by two crossed columns.
    `deep_columns` holds the numeric columns, then indicator columns for the
    vocabulary-based categorical features, then an embedding of `occupation`.
  """
  fc = tf.feature_column

  def vocabulary_column(key, vocabulary):
    """Categorical column over a fixed, in-memory vocabulary."""
    return fc.categorical_column_with_vocabulary_list(key, vocabulary)

  # Continuous features, kept in the order they are fed to the deep part.
  numeric_columns = [
      fc.numeric_column(name)
      for name in ('age', 'education_num', 'capital_gain', 'capital_loss',
                   'hours_per_week')
  ]
  age = numeric_columns[0]

  # Categorical features with a known vocabulary.
  education = vocabulary_column('education', [
      'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
      'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
      '5th-6th', '10th', '1st-4th', 'Preschool', '12th',
  ])
  marital_status = vocabulary_column('marital_status', [
      'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
      'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed',
  ])
  relationship = vocabulary_column('relationship', [
      'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
      'Other-relative',
  ])
  workclass = vocabulary_column('workclass', [
      'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
      'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked',
  ])

  # Hashing is used for `occupation` as an example of a feature without an
  # explicit vocabulary.
  occupation = fc.categorical_column_with_hash_bucket(
      'occupation', hash_bucket_size=1000)

  # Age is additionally bucketized so it can take part in the wide part.
  age_buckets = fc.bucketized_column(
      age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

  # Wide part: base columns first, then the feature crosses.
  wide_columns = [
      education, marital_status, relationship, workclass, occupation,
      age_buckets,
      fc.crossed_column(['education', 'occupation'], hash_bucket_size=1000),
      fc.crossed_column(
          [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
  ]

  # Deep part: numeric columns, indicator-encoded categoricals, and an
  # embedding as an example for the hashed `occupation` column.
  deep_columns = numeric_columns + [
      fc.indicator_column(categorical)
      for categorical in (workclass, education, marital_status, relationship)
  ]
  deep_columns.append(fc.embedding_column(occupation, dimension=8))

  return wide_columns, deep_columns


def build_estimator(model_dir, model_type):
  """Construct the estimator matching `model_type`.

  Args:
    model_dir: Directory handed to the estimator as its `model_dir`.
    model_type: 'wide' selects a `LinearClassifier`, 'deep' selects a
      `DNNClassifier`; any other value selects a
      `DNNLinearCombinedClassifier`.

  Returns:
    The constructed `tf.estimator` classifier.
  """
  wide_columns, deep_columns = build_model_columns()
  hidden_units = [100, 75, 50, 25]

  # Hide GPUs from the session so the model runs on CPU, which trains faster
  # than GPU for this model.
  base_config = tf.estimator.RunConfig()
  cpu_only_session = tf.ConfigProto(device_count={'GPU': 0})
  run_config = base_config.replace(session_config=cpu_only_session)

  # Arguments every estimator variant receives.
  shared_kwargs = {'model_dir': model_dir, 'config': run_config}

  if model_type == 'wide':
    return tf.estimator.LinearClassifier(
        feature_columns=wide_columns, **shared_kwargs)

  if model_type == 'deep':
    return tf.estimator.DNNClassifier(
        feature_columns=deep_columns,
        hidden_units=hidden_units,
        **shared_kwargs)

  return tf.estimator.DNNLinearCombinedClassifier(
      linear_feature_columns=wide_columns,
      dnn_feature_columns=deep_columns,
      dnn_hidden_units=hidden_units,
      **shared_kwargs)


def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Build the batched `tf.data` input pipeline for one census CSV file.

  Args:
    data_file: Path of the CSV file to read; must exist.
    num_epochs: Repeat count passed to `Dataset.repeat`.
    shuffle: If truthy, shuffle the lines with a buffer of
      `_NUM_EXAMPLES['train']` elements before parsing.
    batch_size: Number of examples per batch.

  Returns:
    A dataset yielding `(features, labels)` batches, where `features` maps
    each column name except 'income_bracket' to its parsed tensor and
    `labels` is the result of comparing 'income_bracket' with '>50K'.

  Raises:
    AssertionError: If `data_file` does not exist.
  """
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have run data_download.py and '
      'set the --data_dir argument to the correct path.' % data_file)

  def parse_csv(line):
    """Turn one CSV line into a `(features, label)` pair."""
    print('Parsing', data_file)
    fields = tf.decode_csv(line, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, fields))
    income_bracket = features.pop('income_bracket')
    is_over_50k = tf.equal(income_bracket, '>50K')
    return features, is_over_50k

  # Extract lines from input files using the Dataset API.
  lines = tf.data.TextLineDataset(data_file)
  if shuffle:
    lines = lines.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  examples = lines.map(parse_csv, num_parallel_calls=5)

  # Repeat is applied after shuffling, rather than before, to prevent
  # separate epochs from blending together.
  return examples.repeat(num_epochs).batch(batch_size)
172

173

Karmel Allison's avatar
Karmel Allison committed
174
def main(_):
  """Train the configured model, evaluating and printing metrics as it goes.

  Removes any existing `FLAGS.model_dir`, builds the estimator selected by
  `FLAGS.model_type`, then repeats
  `FLAGS.train_epochs // FLAGS.epochs_between_evals` times: train for
  `FLAGS.epochs_between_evals` epochs on 'adult.data', evaluate once on
  'adult.test', and print the evaluation metrics sorted by name.

  Args:
    _: Unused positional argument (this function is handed to `tf.app.run`).
  """
  # Clean up the model directory if present
  shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
  model = build_estimator(FLAGS.model_dir, FLAGS.model_type)

  train_file = os.path.join(FLAGS.data_dir, 'adult.data')
  test_file = os.path.join(FLAGS.data_dir, 'adult.test')

  # Each train() call consumes `epochs_between_evals` epochs of shuffled
  # training data. This must be the same flag that sizes the loop below;
  # the parser in this file only sets `epochs_between_evals`.
  def train_input_fn():
    return input_fn(
        train_file, FLAGS.epochs_between_evals, True, FLAGS.batch_size)

  # Evaluation makes a single unshuffled pass over the test file.
  def eval_input_fn():
    return input_fn(test_file, 1, False, FLAGS.batch_size)

  train_hooks = hooks_helper.get_train_hooks(
      FLAGS.hooks, batch_size=FLAGS.batch_size,
      tensors_to_log={'average_loss': 'head/truediv',
                      'loss': 'head/weighted_loss/Sum'})

  # Train and evaluate the model every `FLAGS.epochs_between_evals` epochs.
  for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
    model.train(input_fn=train_input_fn, hooks=train_hooks)
    results = model.evaluate(input_fn=eval_input_fn)

    # Display evaluation metrics
    print('Results at epoch', (n + 1) * FLAGS.epochs_between_evals)
    print('-' * 60)

    for key in sorted(results):
      print('%s: %s' % (key, results[key]))


207
208
class WideDeepArgParser(argparse.ArgumentParser):
  """Command-line parser for the wide & deep census example.

  Uses `parsers.BaseParser` as a parent parser, adds the `--model_type`
  option, and installs example-specific default values.
  """

  def __init__(self):
    parent_parsers = [parsers.BaseParser()]
    super(WideDeepArgParser, self).__init__(parents=parent_parsers)

    # Which estimator variant `build_estimator` should construct.
    self.add_argument(
        '--model_type', '-mt',
        type=str,
        choices=['wide', 'deep', 'wide_deep'],
        default='wide_deep',
        metavar='<MT>',
        help='[default %(default)s] Valid model types: wide, deep, wide_deep.')

    # Defaults specific to this example; the corresponding options are
    # presumably declared by the parent parser -- verify against BaseParser.
    example_defaults = {
        'data_dir': '/tmp/census_data',
        'model_dir': '/tmp/census_model',
        'train_epochs': 40,
        'epochs_between_evals': 2,
        'batch_size': 40,
    }
    self.set_defaults(**example_defaults)


225
226
if __name__ == '__main__':
  # Script entry point: parse the known flags into the module-level FLAGS
  # that main() reads, and forward everything else to tf.app.run.
  tf.logging.set_verbosity(tf.logging.INFO)

  parser = WideDeepArgParser()
  FLAGS, unparsed = parser.parse_known_args()

  # Re-attach the program name in front of the arguments argparse left over.
  program_name = sys.argv[0]
  tf.app.run(main=main, argv=[program_name] + unparsed)