wide_deep.py 9.92 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
田传武's avatar
田传武 committed
15
"""Example code for TensorFlow Wide & Deep Tutorial using tf.estimator API."""
16
17
18
19
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

20
import os
21
22
import shutil

23
24
from absl import app as absl_app
from absl import flags
Karmel Allison's avatar
Karmel Allison committed
25
import tensorflow as tf  # pylint: disable=g-bad-import-order
26

27
from official.utils.flags import core as flags_core
28
from official.utils.logs import hooks_helper
29
from official.utils.logs import logger
30
from official.utils.misc import model_helpers
31

32

33
34
35
36
37
38
39
40
41
42
# Column names of the census CSV files, in file order. `input_fn` zips these
# names with the fields decoded from each CSV line and pops 'income_bracket'
# off the resulting feature dict to use as the label.
_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'gender',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income_bracket'
]

# Per-column defaults handed to `tf.decode_csv` as `record_defaults`. Entries
# are positionally aligned with _CSV_COLUMNS: [0] for the integer-valued
# columns and [''] for the string-valued ones.
_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
                        [0], [0], [0], [''], ['']]

43
44
45
46
# Number of examples in each data split. The 'train' count is used by
# `input_fn` as the shuffle buffer size.
_NUM_EXAMPLES = {
    'train': 32561,
    'validation': 16281,
}
Neal Wu's avatar
Neal Wu committed
47

48

49
50
51
# Prefix prepended to the loss tensor names that the train hooks log, keyed by
# model type. Model types not listed here (e.g. 'wide_deep') get no prefix.
LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}


52
53
54
def define_wide_deep_flags():
  """Add supervised learning flags, as well as wide-deep model type.

  Registers the shared base and benchmark flags from `flags_core`, adopts them
  as key flags of this module, defines the `--model_type` / `-mt` enum flag
  and overrides several shared-flag defaults for the census example.
  """
  flags_core.define_base()
  flags_core.define_benchmark()

  flags.adopt_module_key_flags(flags_core)

  flags.DEFINE_enum(
      name="model_type", short_name="mt", default="wide_deep",
      enum_values=['wide', 'deep', 'wide_deep'],
      help="Select model topology.")

  # Defaults tuned for the census income example.
  flags_core.set_defaults(data_dir='/tmp/census_data',
                          model_dir='/tmp/census_model',
                          train_epochs=40,
                          epochs_between_evals=2,
                          batch_size=40)


71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def build_model_columns():
  """Create the feature columns used by the wide and the deep model parts.

  Returns:
    A `(wide_columns, deep_columns)` tuple. `wide_columns` holds the
    categorical, bucketized and crossed columns; `deep_columns` holds the
    numeric columns plus indicator/embedding wrappers of the categorical ones.
  """
  fc = tf.feature_column
  vocab_column = fc.categorical_column_with_vocabulary_list

  # Numeric inputs, fed to the deep part as-is.
  age = fc.numeric_column('age')
  education_num = fc.numeric_column('education_num')
  capital_gain = fc.numeric_column('capital_gain')
  capital_loss = fc.numeric_column('capital_loss')
  hours_per_week = fc.numeric_column('hours_per_week')

  # Categorical inputs with a fixed, known vocabulary.
  education = vocab_column(
      'education', [
          'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
          'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
          '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])
  marital_status = vocab_column(
      'marital_status', [
          'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
          'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])
  relationship = vocab_column(
      'relationship', [
          'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
          'Other-relative'])
  workclass = vocab_column(
      'workclass', [
          'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
          'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])

  # Occupation is hashed instead of enumerated, as an example of hashing.
  occupation = fc.categorical_column_with_hash_bucket(
      'occupation', hash_bucket_size=1000)

  # Age is additionally bucketized for use in the wide part.
  age_buckets = fc.bucketized_column(
      age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

  # Feature crosses, used only by the wide part.
  education_x_occupation = fc.crossed_column(
      ['education', 'occupation'], hash_bucket_size=1000)
  age_x_education_x_occupation = fc.crossed_column(
      [age_buckets, 'education', 'occupation'], hash_bucket_size=1000)

  wide_columns = [
      education, marital_status, relationship, workclass, occupation,
      age_buckets,
      education_x_occupation,
      age_x_education_x_occupation,
  ]

  deep_columns = [
      age, education_num, capital_gain, capital_loss, hours_per_week,
  ]
  # Vocabulary columns go into the deep part through indicator columns.
  deep_columns.extend(
      fc.indicator_column(categorical)
      for categorical in (workclass, education, marital_status, relationship))
  # The hashed occupation column is embedded, as an example of embedding.
  deep_columns.append(fc.embedding_column(occupation, dimension=8))

  return wide_columns, deep_columns


def build_estimator(model_dir, model_type):
  """Construct the estimator that matches `model_type`.

  Args:
    model_dir: Directory passed to the estimator as its `model_dir`.
    model_type: 'wide' for a LinearClassifier, 'deep' for a DNNClassifier;
      any other value yields a DNNLinearCombinedClassifier.

  Returns:
    The configured `tf.estimator` classifier.
  """
  linear_columns, dnn_columns = build_model_columns()
  dnn_layers = [100, 75, 50, 25]

  # Hide GPUs from the session so the model runs on CPU, which trains faster
  # than GPU for this model.
  cpu_only_session = tf.ConfigProto(device_count={'GPU': 0})
  run_config = tf.estimator.RunConfig().replace(
      session_config=cpu_only_session)

  if model_type == 'wide':
    return tf.estimator.LinearClassifier(
        model_dir=model_dir,
        feature_columns=linear_columns,
        config=run_config)

  if model_type == 'deep':
    return tf.estimator.DNNClassifier(
        model_dir=model_dir,
        feature_columns=dnn_columns,
        hidden_units=dnn_layers,
        config=run_config)

  return tf.estimator.DNNLinearCombinedClassifier(
      model_dir=model_dir,
      linear_feature_columns=linear_columns,
      dnn_feature_columns=dnn_columns,
      dnn_hidden_units=dnn_layers,
      config=run_config)


def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator.

  Args:
    data_file: Path of the CSV file to read; it must already exist.
    num_epochs: Number of times the dataset is repeated.
    shuffle: Whether to shuffle the lines before parsing.
    batch_size: Number of examples per batch.

  Returns:
    A `tf.data` dataset of `(features, labels)` batches, where `features` maps
    each name in `_CSV_COLUMNS` except 'income_bracket' to its column and
    `labels` is true where the income bracket equals '>50K'.

  Raises:
    AssertionError: If `data_file` does not exist.
  """
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have run data_download.py and '
      'set the --data_dir argument to the correct path.' % data_file)

  def parse_csv(value):
    """Decode CSV text into a `(features, labels)` pair."""
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    return features, tf.equal(labels, '>50K')

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset
197

198

199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
def export_model(model, model_type, export_dir):
  """Write `model` to `export_dir` in SavedModel format.

  The serving input receiver parses serialized `tf.Example` protos using a
  feature spec derived from the feature columns of the chosen model type.

  Args:
    model: Estimator object
    model_type: string indicating model type. "wide", "deep" or "wide_deep"
    export_dir: directory to export the model.
  """
  wide, deep = build_model_columns()
  # Any model type other than 'wide' or 'deep' uses both column sets.
  columns_by_type = {'wide': wide, 'deep': deep}
  selected_columns = columns_by_type.get(model_type, wide + deep)

  parse_spec = tf.feature_column.make_parse_example_spec(selected_columns)
  serving_input_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn(parse_spec))
  model.export_savedmodel(export_dir, serving_input_fn)


220
221
222
223
224
225
226
def run_wide_deep(flags_obj):
  """Run Wide-Deep training and eval loop.

  Removes any existing model directory, builds the estimator selected by
  `flags_obj.model_type`, then alternates training and evaluation, logging
  the evaluation metrics after every round. The loop ends early when
  `model_helpers.past_stop_threshold` reports that the accuracy has passed
  `flags_obj.stop_threshold`. The model is exported afterwards if
  `flags_obj.export_dir` is set.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  # Clean up the model directory if present
  shutil.rmtree(flags_obj.model_dir, ignore_errors=True)
  model = build_estimator(flags_obj.model_dir, flags_obj.model_type)

  train_file = os.path.join(flags_obj.data_dir, 'adult.data')
  test_file = os.path.join(flags_obj.data_dir, 'adult.test')

  # Each call to model.train consumes `epochs_between_evals` shuffled epochs
  # of the training file; each evaluation reads the test file exactly once.
  def train_input_fn():
    return input_fn(
        train_file, flags_obj.epochs_between_evals, True, flags_obj.batch_size)

  def eval_input_fn():
    return input_fn(test_file, 1, False, flags_obj.batch_size)

  run_params = {
      'batch_size': flags_obj.batch_size,
      'train_epochs': flags_obj.train_epochs,
      'model_type': flags_obj.model_type,
  }

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params,
                                test_id=flags_obj.benchmark_test_id)

  # The loss tensors are named differently depending on the model type.
  loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks, batch_size=flags_obj.batch_size,
      tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
                      'loss': loss_prefix + 'head/weighted_loss/Sum'})

  # Train and evaluate the model every `flags.epochs_between_evals` epochs.
  # NOTE: the integer division drops any remainder epochs, and the loop body
  # never runs when train_epochs < epochs_between_evals.
  for n in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
    model.train(input_fn=train_input_fn, hooks=train_hooks)
    results = model.evaluate(input_fn=eval_input_fn)

    # Display evaluation metrics
    tf.logging.info('Results at epoch %d / %d',
                    (n + 1) * flags_obj.epochs_between_evals,
                    flags_obj.train_epochs)
    tf.logging.info('-' * 60)

    for key in sorted(results):
      # Pass the values as logging args, matching the call above, so the
      # message is only formatted when it is actually emitted.
      tf.logging.info('%s: %s', key, results[key])

    benchmark_logger.log_evaluation_result(results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, results['accuracy']):
      break

  # Export the model
  if flags_obj.export_dir is not None:
    export_model(model, flags_obj.model_type, flags_obj.export_dir)
281
282


283
def main(_):
  """Run the wide-deep loop inside the benchmark logging context.

  Args:
    _: Unused positional argument supplied by `absl_app.run`.
  """
  with logger.benchmark_context(flags.FLAGS):
    run_wide_deep(flags.FLAGS)
286
287


288
289
if __name__ == '__main__':
  # Script entry point: enable INFO logging, register the flags, then hand
  # control to absl, which calls `main`.
  tf.logging.set_verbosity(tf.logging.INFO)
  define_wide_deep_flags()
  absl_app.run(main)