# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example code for TensorFlow Wide & Deep Tutorial using tf.estimator API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import shutil
import sys

import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.arg_parsers import parsers
from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers

_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'gender',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income_bracket'
]

_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
                        [0], [0], [0], [''], ['']]
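# Note: each default above also fixes its column's dtype for tf.decode_csv:
# [0] parses as an int32 column, [''] as a string column.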

_NUM_EXAMPLES = {
    'train': 32561,
    'validation': 16281,
}


LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}


def build_model_columns():
  """Builds a set of wide and deep feature columns."""
  # Continuous columns
  age = tf.feature_column.numeric_column('age')
  education_num = tf.feature_column.numeric_column('education_num')
  capital_gain = tf.feature_column.numeric_column('capital_gain')
  capital_loss = tf.feature_column.numeric_column('capital_loss')
  hours_per_week = tf.feature_column.numeric_column('hours_per_week')

  education = tf.feature_column.categorical_column_with_vocabulary_list(
      'education', [
          'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
          'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
          '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])

  marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
      'marital_status', [
          'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
          'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])

  relationship = tf.feature_column.categorical_column_with_vocabulary_list(
      'relationship', [
          'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
          'Other-relative'])

  workclass = tf.feature_column.categorical_column_with_vocabulary_list(
      'workclass', [
          'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
          'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])

  # To show an example of hashing:
  occupation = tf.feature_column.categorical_column_with_hash_bucket(
      'occupation', hash_bucket_size=1000)
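  # Hashing maps each occupation string to one of 1,000 buckets, so distinct
  # values may collide; the model tolerates this in exchange for not needing
  # a fixed vocabulary.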

  # Transformations.
  age_buckets = tf.feature_column.bucketized_column(
      age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

  # Wide columns and deep columns.
  base_columns = [
      education, marital_status, relationship, workclass, occupation,
      age_buckets,
  ]

  crossed_columns = [
      tf.feature_column.crossed_column(
          ['education', 'occupation'], hash_bucket_size=1000),
      tf.feature_column.crossed_column(
          [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
  ]

  wide_columns = base_columns + crossed_columns

  deep_columns = [
      age,
      education_num,
      capital_gain,
      capital_loss,
      hours_per_week,
      tf.feature_column.indicator_column(workclass),
      tf.feature_column.indicator_column(education),
      tf.feature_column.indicator_column(marital_status),
      tf.feature_column.indicator_column(relationship),
      # To show an example of embedding
      tf.feature_column.embedding_column(occupation, dimension=8),
  ]

  return wide_columns, deep_columns


def build_estimator(model_dir, model_type):
  """Build an estimator appropriate for the given model type."""
  wide_columns, deep_columns = build_model_columns()
  hidden_units = [100, 75, 50, 25]

  # Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
  # trains faster than GPU for this model.
  run_config = tf.estimator.RunConfig().replace(
      session_config=tf.ConfigProto(device_count={'GPU': 0}))

  if model_type == 'wide':
    return tf.estimator.LinearClassifier(
        model_dir=model_dir,
        feature_columns=wide_columns,
        config=run_config)
  elif model_type == 'deep':
    return tf.estimator.DNNClassifier(
        model_dir=model_dir,
        feature_columns=deep_columns,
        hidden_units=hidden_units,
        config=run_config)
  else:
    return tf.estimator.DNNLinearCombinedClassifier(
        model_dir=model_dir,
        linear_feature_columns=wide_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=hidden_units,
        config=run_config)


def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have run data_download.py and '
      'set the --data_dir argument to the correct path.' % data_file)

  def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
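    # The label is a boolean: True when the row's income bracket is '>50K'.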
    return features, tf.equal(labels, '>50K')

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
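    # A shuffle buffer as large as the training set gives a full uniform
    # shuffle of the data each epoch.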
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset


def export_model(model, model_type, export_dir):
  """Export to SavedModel format.

  Args:
    model: Estimator object to export.
    model_type: One of 'wide', 'deep', or 'wide_deep'.
    export_dir: Directory to which the SavedModel is written.
  """
  wide_columns, deep_columns = build_model_columns()
  if model_type == 'wide':
    columns = wide_columns
  elif model_type == 'deep':
    columns = deep_columns
  else:
    columns = wide_columns + deep_columns
  feature_spec = tf.feature_column.make_parse_example_spec(columns)
  example_input_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
  model.export_savedmodel(export_dir, example_input_fn)
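  # The resulting timestamped SavedModel directory can be inspected with,
  # for example:
  #   saved_model_cli show --dir <export_dir>/<timestamp> --all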


def main(argv):
  parser = WideDeepArgParser()
  flags = parser.parse_args(args=argv[1:])

  # Clean up the model directory if present
  shutil.rmtree(flags.model_dir, ignore_errors=True)
  model = build_estimator(flags.model_dir, flags.model_type)

  train_file = os.path.join(flags.data_dir, 'adult.data')
  test_file = os.path.join(flags.data_dir, 'adult.test')

  # Set up the input functions used for training and evaluation.
  def train_input_fn():
    return input_fn(
        train_file, flags.epochs_between_evals, True, flags.batch_size)

  def eval_input_fn():
    return input_fn(test_file, 1, False, flags.batch_size)

  loss_prefix = LOSS_PREFIX.get(flags.model_type, '')
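  # The loss tensors live under the 'linear/' or 'dnn/' name scope for the
  # single-tower models; the combined model logs them without a prefix.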
  train_hooks = hooks_helper.get_train_hooks(
      flags.hooks, batch_size=flags.batch_size,
      tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
                      'loss': loss_prefix + 'head/weighted_loss/Sum'})

  # Train and evaluate the model every `flags.epochs_between_evals` epochs.
  for n in range(flags.train_epochs // flags.epochs_between_evals):
    model.train(input_fn=train_input_fn, hooks=train_hooks)
    results = model.evaluate(input_fn=eval_input_fn)

    # Display evaluation metrics
    print('Results at epoch', (n + 1) * flags.epochs_between_evals)
    print('-' * 60)

    for key in sorted(results):
      print('%s: %s' % (key, results[key]))

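    # End training early once eval accuracy crosses the optional
    # --stop_threshold flag.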
    if model_helpers.past_stop_threshold(
        flags.stop_threshold, results['accuracy']):
      break

  # Export the model
  if flags.export_dir is not None:
    export_model(model, flags.model_type, flags.export_dir)


class WideDeepArgParser(argparse.ArgumentParser):
  """Argument parser for running the wide deep model."""

  def __init__(self):
    super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()])
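    # BaseParser supplies the shared flags consumed in main() (--data_dir,
    # --model_dir, --train_epochs, --epochs_between_evals, --batch_size,
    # --hooks, --stop_threshold, --export_dir); this subclass adds
    # --model_type and overrides several defaults.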
    self.add_argument(
        '--model_type', '-mt', type=str, default='wide_deep',
        choices=['wide', 'deep', 'wide_deep'],
        help='[default %(default)s] Valid model types: wide, deep, wide_deep.',
        metavar='<MT>')
    self.set_defaults(
        data_dir='/tmp/census_data',
        model_dir='/tmp/census_model',
        train_epochs=40,
        epochs_between_evals=2,
        batch_size=40)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  main(argv=sys.argv)