Commit 9dafea91 authored by sunxx1's avatar sunxx1

Merge branch 'qianyj_tf' into 'main'

update tf code

See merge request dcutoolkit/deeplearing/dlexamples_new!35
parents 92a2ca36 a4146470
[32, 8, 8, 4, 0.23128163814544678, 0.22117376327514648, 4100.51806640625, 32, 8, 8, 4, 1.1768392324447632, 0.2728465795516968, 5832.6416015625]
[32, 8, 8, 4, 0.7616699934005737, 0.5485763549804688, 4106.8720703125, 32, 8, 8, 4, -0.056346118450164795, 0.5792689919471741, 2972.37255859375]
[32, 16, 16, 3, 0.9722558259963989, 0.18413543701171875, 12374.20703125, 32, 16, 16, 3, 1.6126631498336792, -1.096894383430481, -0.041595458984375]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module tests generic behavior of reference data tests.
This test is not intended to test every layer of interest, and models should
test the layers that affect them. This test is primarily focused on ensuring
that reference_data.BaseTest functions as intended. If there is a legitimate
change such as a change to TensorFlow which changes graph construction, tests
can be regenerated with the following command:
$ python3 reference_data_test.py -regen
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import unittest
import warnings

import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.testing import reference_data


class GoldenBaseTest(reference_data.BaseTest):
  """Class to ensure that reference data testing runs properly."""

  @property
  def test_name(self):
    return "reference_data_test"

  def _uniform_random_ops(self, test=False, wrong_name=False, wrong_shape=False,
                          bad_seed=False, bad_function=False):
    """Tests number generation and failure modes.

    This test is of a very simple graph: the generation of a 1x1 random tensor.
    However, it is also used to confirm that the tests are actually checking
    properly by failing in predefined ways.

    Args:
      test: Whether or not to run as a test case.
      wrong_name: Whether to assign the wrong name to the tensor.
      wrong_shape: Whether to create a tensor with the wrong shape.
      bad_seed: Whether or not to perturb the random seed.
      bad_function: Whether to perturb the correctness function.
    """
    name = "uniform_random"

    g = tf.Graph()
    with g.as_default():
      seed = self.name_to_seed(name)
      seed = seed + 1 if bad_seed else seed
      tf.set_random_seed(seed)
      tensor_name = "wrong_tensor" if wrong_name else "input_tensor"
      tensor_shape = (1, 2) if wrong_shape else (1, 1)
      input_tensor = tf.get_variable(
          tensor_name, dtype=tf.float32,
          initializer=tf.random_uniform(tensor_shape, maxval=1)
      )

    def correctness_function(tensor_result):
      result = float(tensor_result[0, 0])
      result = result + 0.1 if bad_function else result
      return [result]

    self._save_or_test_ops(
        name=name, graph=g, ops_to_eval=[input_tensor], test=test,
        correctness_function=correctness_function
    )
  def _dense_ops(self, test=False):
    name = "dense"

    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(self.name_to_seed(name))
      input_tensor = tf.get_variable(
          "input_tensor", dtype=tf.float32,
          initializer=tf.random_uniform((1, 2), maxval=1)
      )
      layer = tf.layers.dense(inputs=input_tensor, units=4)
      layer = tf.layers.dense(inputs=layer, units=1)

    self._save_or_test_ops(
        name=name, graph=g, ops_to_eval=[layer], test=test,
        correctness_function=self.default_correctness_function
    )
  def test_uniform_random(self):
    self._uniform_random_ops(test=True)

  def test_tensor_name_error(self):
    with self.assertRaises(AssertionError):
      self._uniform_random_ops(test=True, wrong_name=True)

  def test_tensor_shape_error(self):
    with self.assertRaises(AssertionError):
      self._uniform_random_ops(test=True, wrong_shape=True)

  @unittest.skipIf(sys.version_info[0] == 2,
                   "catch_warning doesn't catch tf.logging.warn in py 2.")
  def test_bad_seed(self):
    with warnings.catch_warnings(record=True) as warn_catch:
      self._uniform_random_ops(test=True, bad_seed=True)
      assert len(warn_catch) == 1, "Test did not warn of minor graph change."

  def test_incorrectness_function(self):
    with self.assertRaises(AssertionError):
      self._uniform_random_ops(test=True, bad_function=True)

  def test_dense(self):
    self._dense_ops(test=True)

  def regenerate(self):
    self._uniform_random_ops(test=False)
    self._dense_ops(test=False)


if __name__ == "__main__":
  reference_data.main(argv=sys.argv, test_class=GoldenBaseTest)
#!/bin/bash
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Presubmit script that runs tests and lint under the local environment.
# Make sure that tensorflow and pylint are installed.
# usage: models >: ./official/utils/testing/scripts/presubmit.sh
# usage: models >: ./official/utils/testing/scripts/presubmit.sh lint py2_test py3_test
set +x
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR/../../../.."
MODEL_ROOT="$(pwd)"
export PYTHONPATH="$PYTHONPATH:${MODEL_ROOT}"
cd official

lint() {
  local exit_code=0
  RC_FILE="utils/testing/pylint.rcfile"
  PROTO_SKIP="DO\sNOT\sEDIT!"

  echo "===========Running lint test============"
  for file in `find . -name '*.py' ! -name '*test.py' -print`
  do
    if grep ${PROTO_SKIP} ${file}; then
      echo "Linting ${file} (Skipped: Machine generated file)"
    else
      echo "Linting ${file}"
      pylint --rcfile="${RC_FILE}" "${file}" || exit_code=$?
    fi
  done

  # More lenient for test files.
  for file in `find . -name '*test.py' -print`
  do
    echo "Linting ${file}"
    pylint --rcfile="${RC_FILE}" --disable=missing-docstring,protected-access "${file}" || exit_code=$?
  done

  return "${exit_code}"
}

py_test() {
  local PY_BINARY="$1"
  local exit_code=0

  echo "===========Running Python test============"
  for test_file in `find . -name '*test.py' -print`
  do
    echo "Testing ${test_file}"
    ${PY_BINARY} "${test_file}" || exit_code=$?
  done

  return "${exit_code}"
}

py2_test() {
  local PY_BINARY=$(which python2)
  py_test "$PY_BINARY"
  return $?
}

py3_test() {
  local PY_BINARY=$(which python3)
  py_test "$PY_BINARY"
  return $?
}

test_result=0

if [ "$#" -eq 0 ]; then
  TESTS="lint py2_test py3_test"
else
  TESTS="$@"
fi

# ${TESTS} is deliberately unquoted so that it word-splits into the
# individual test names; quoting it would run only the first test.
for t in ${TESTS}; do
  ${t} || test_result=$?
done

exit "${test_result}"
# Predicting Income with the Census Income Dataset
## Overview
The [Census Income Data Set](https://archive.ics.uci.edu/ml/datasets/Census+Income) contains over 48,000 samples with attributes including age, occupation, education, and income (a binary label, either `>50K` or `<=50K`). The dataset is split into roughly 32,000 training and 16,000 testing samples.
Here, we use the [wide and deep model](https://research.googleblog.com/2016/06/wide-deep-learning-better-together-with.html) to predict the income labels. The **wide model** is able to memorize feature interactions in data with a large number of features, but it is not able to generalize these learned interactions to new data. The **deep model** generalizes well but is unable to learn exceptions within the data. The **wide and deep model** combines the two and is able to generalize while still learning exceptions.
For the purposes of this example code, the Census Income Data Set was chosen to allow the model to train in a reasonable amount of time. You'll notice that the deep model performs almost as well as the wide and deep model on this dataset. The wide and deep model truly shines on larger data sets with high-cardinality features, where each feature has millions/billions of unique possible values (which is the specialty of the wide model).
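To make the combination concrete, the sketch below shows how a wide half and a deep half are typically wired together with the canned `tf.estimator.DNNLinearCombinedClassifier`. This is a minimal, illustrative example, not the exact code in this directory (the real model lives in `census_main.py` and `census_dataset.py`); the column choices and hidden-unit sizes here are assumptions.

```
import tensorflow as tf

# Illustrative columns: one continuous, one high-cardinality categorical.
age = tf.feature_column.numeric_column('age')
occupation = tf.feature_column.categorical_column_with_hash_bucket(
    'occupation', hash_bucket_size=1000)

# Wide half: sparse columns the linear model can memorize.
wide_columns = [occupation]
# Deep half: dense inputs plus an embedding of the sparse column.
deep_columns = [age,
                tf.feature_column.embedding_column(occupation, dimension=8)]

model = tf.estimator.DNNLinearCombinedClassifier(
    model_dir='/tmp/census_model',
    linear_feature_columns=wide_columns,
    dnn_feature_columns=deep_columns,
    dnn_hidden_units=[100, 50])
```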
Finally, a key point. As a modeler and developer, think about how this dataset is used and the potential benefits and harm a model's predictions can cause. A model like this could reinforce societal biases and disparities. Is a feature relevant to the problem you want to solve, or will it introduce bias? For more information, read about [ML fairness](https://developers.google.com/machine-learning/fairness-overview/).
---
The code sample in this directory uses the high level `tf.estimator.Estimator` API. This API is great for fast iteration and quickly adapting models to your own datasets without major code overhauls. It allows you to move from single-worker training to distributed training, and it makes it easy to export model binaries for prediction.
The input function for the `Estimator` uses `tf.data.TextLineDataset`, which creates a `Dataset` object. The `Dataset` API makes it easy to apply transformations (map, batch, shuffle, etc.) to the data. [Read more here](https://www.tensorflow.org/guide/datasets).
The `Estimator` and `Dataset` APIs are both highly encouraged for fast development and efficient training.
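As a rough sketch of how the two APIs meet, the `Estimator` is trained and evaluated by handing it a callable that builds the `Dataset`. This is illustrative, not the exact loop in `census_main.py`; it assumes `model` from the sketch above, the `input_fn` defined in `census_dataset.py`, and placeholder epoch/batch values.

```
# Minimal train/evaluate wiring; paths assume the default --data_dir.
model.train(input_fn=lambda: input_fn(
    '/tmp/census_data/adult.data', num_epochs=2, shuffle=True, batch_size=40))

results = model.evaluate(input_fn=lambda: input_fn(
    '/tmp/census_data/adult.test', num_epochs=1, shuffle=False, batch_size=40))
print('Accuracy: %s' % results['accuracy'])
```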
## Running the code
First make sure you've [added the models folder to your Python path](/official/#running-the-models); otherwise you may encounter an error like `ImportError: No module named official.wide_deep`.
### Setup
The [Census Income Data Set](https://archive.ics.uci.edu/ml/datasets/Census+Income) that this sample uses for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). We have provided a script that downloads and cleans the necessary files.
```
python census_dataset.py
```
This will download the files to `/tmp/census_data`. To change the directory, set the `--data_dir` flag.
### Training
You can run the code locally as follows:
```
python census_main.py
```
The model is saved to `/tmp/census_model` by default, which can be changed using the `--model_dir` flag.
To run the *wide* or *deep*-only models, set the `--model_type` flag to `wide` or `deep`. Other flags are configurable as well; see `census_main.py` for details.
The final accuracy should be over 83% with any of the three model types.
You can also experiment with the `--inter` and `--intra` flags to explore inter/intra op parallelism for potentially better performance, as follows:
```
python census_main.py --inter=<int> --intra=<int>
```
Note that the inter/intra op settings above do not affect model accuracy. They are TensorFlow framework configurations that only affect execution time.
For more details regarding the above inter/intra flags, please refer to [Optimizing_for_CPU](https://www.tensorflow.org/performance/performance_guide#optimizing_for_cpu) or [TensorFlow config.proto source code](https://github.com/tensorflow/tensorflow/blob/26b4dfa65d360f2793ad75083c797d57f8661b93/tensorflow/core/protobuf/config.proto#L165).
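Under the hood, flags like these end up in the `tf.ConfigProto` handed to the estimator's `RunConfig`. Below is a minimal sketch of the equivalent wiring; the thread counts are placeholders, the column lists are from the earlier sketch, and the actual flag handling lives in `census_main.py` and `official/utils/flags`.

```
import tensorflow as tf

# 0 means "let TensorFlow pick"; a positive value pins the pool size.
session_config = tf.ConfigProto(
    inter_op_parallelism_threads=2,   # parallelism across independent ops
    intra_op_parallelism_threads=4)   # parallelism inside a single op

run_config = tf.estimator.RunConfig(session_config=session_config)
model = tf.estimator.DNNLinearCombinedClassifier(
    linear_feature_columns=wide_columns,
    dnn_feature_columns=deep_columns,
    dnn_hidden_units=[100, 50],
    config=run_config)
```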
### TensorBoard
Run TensorBoard to inspect the details about the graph and training progression.
```
tensorboard --logdir=/tmp/census_model
```
## Inference with SavedModel
You can export the model into TensorFlow [SavedModel](https://www.tensorflow.org/guide/saved_model) format by using the argument `--export_dir`:
```
python census_main.py --export_dir /tmp/wide_deep_saved_model
```
After the model finishes training, use [`saved_model_cli`](https://www.tensorflow.org/guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
Try the following commands to inspect the SavedModel:
**Replace `${TIMESTAMP}` with the folder produced (e.g. 1524249124)**
```
# List possible tag_sets. Only one metagraph is saved, so there will be one option.
saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/
# Show SignatureDefs for tag_set=serve. SignatureDefs define the outputs to show.
saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
--tag_set serve --all
```
### Inference
Let's use the model to predict the income group of two examples:
```
saved_model_cli run --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
--tag_set serve --signature_def="predict" \
--input_examples='examples=[{"age":[46.], "education_num":[10.], "capital_gain":[7688.], "capital_loss":[0.], "hours_per_week":[38.]}, {"age":[24.], "education_num":[13.], "capital_gain":[0.], "capital_loss":[0.], "hours_per_week":[50.]}]'
```
This will print out the predicted classes and class probabilities. Class 0 is the `<=50K` group and class 1 is the `>50K` group.
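If you prefer to drive the exported model from Python rather than `saved_model_cli`, here is a minimal sketch using the TF 1.x `tf.contrib.predictor` helper. It assumes the default parsing serving input receiver that the `Estimator` export produces (hence the `examples` input key) and reuses the example timestamp from above.

```
import tensorflow as tf

# Serialize one tf.train.Example with the same features as the CLI call above.
example = tf.train.Example(features=tf.train.Features(feature={
    'age': tf.train.Feature(float_list=tf.train.FloatList(value=[46.0])),
    'education_num': tf.train.Feature(float_list=tf.train.FloatList(value=[10.0])),
    'capital_gain': tf.train.Feature(float_list=tf.train.FloatList(value=[7688.0])),
    'capital_loss': tf.train.Feature(float_list=tf.train.FloatList(value=[0.0])),
    'hours_per_week': tf.train.Feature(float_list=tf.train.FloatList(value=[38.0])),
}))

predict_fn = tf.contrib.predictor.from_saved_model(
    '/tmp/wide_deep_saved_model/1524249124',  # substitute your ${TIMESTAMP}
    signature_def_key='predict')
print(predict_fn({'examples': [example.SerializeToString()]}))
```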
## Additional Links
If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).
You can also [run this model on Cloud ML Engine](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction), which provides [hyperparameter tuning](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#hyperparameter_tuning) to maximize your model's results and enables [deploying your model for prediction](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#deploy_a_model_to_support_prediction).
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and clean the Census Income Dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# pylint: disable=wrong-import-order
from absl import app as absl_app
from absl import flags
from six.moves import urllib
import tensorflow as tf
# pylint: enable=wrong-import-order

from official.utils.flags import core as flags_core

DATA_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult'
TRAINING_FILE = 'adult.data'
TRAINING_URL = '%s/%s' % (DATA_URL, TRAINING_FILE)
EVAL_FILE = 'adult.test'
EVAL_URL = '%s/%s' % (DATA_URL, EVAL_FILE)

_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'gender',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income_bracket'
]

_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
                        [0], [0], [0], [''], ['']]

_HASH_BUCKET_SIZE = 1000

_NUM_EXAMPLES = {
    'train': 32561,
    'validation': 16281,
}


def _download_and_clean_file(filename, url):
  """Downloads data from url, and makes changes to match the CSV format."""
  temp_file, _ = urllib.request.urlretrieve(url)
  with tf.gfile.Open(temp_file, 'r') as temp_eval_file:
    with tf.gfile.Open(filename, 'w') as eval_file:
      for line in temp_eval_file:
        line = line.strip()
        line = line.replace(', ', ',')
        if not line or ',' not in line:
          continue
        if line[-1] == '.':
          line = line[:-1]
        line += '\n'
        eval_file.write(line)
  tf.gfile.Remove(temp_file)


def download(data_dir):
  """Download census data if it is not already present."""
  tf.gfile.MakeDirs(data_dir)

  training_file_path = os.path.join(data_dir, TRAINING_FILE)
  if not tf.gfile.Exists(training_file_path):
    _download_and_clean_file(training_file_path, TRAINING_URL)

  eval_file_path = os.path.join(data_dir, EVAL_FILE)
  if not tf.gfile.Exists(eval_file_path):
    _download_and_clean_file(eval_file_path, EVAL_URL)


def build_model_columns():
  """Builds a set of wide and deep feature columns."""
  # Continuous variable columns
  age = tf.feature_column.numeric_column('age')
  education_num = tf.feature_column.numeric_column('education_num')
  capital_gain = tf.feature_column.numeric_column('capital_gain')
  capital_loss = tf.feature_column.numeric_column('capital_loss')
  hours_per_week = tf.feature_column.numeric_column('hours_per_week')

  education = tf.feature_column.categorical_column_with_vocabulary_list(
      'education', [
          'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
          'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
          '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])

  marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
      'marital_status', [
          'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
          'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])

  relationship = tf.feature_column.categorical_column_with_vocabulary_list(
      'relationship', [
          'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
          'Other-relative'])

  workclass = tf.feature_column.categorical_column_with_vocabulary_list(
      'workclass', [
          'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
          'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])

  # To show an example of hashing:
  occupation = tf.feature_column.categorical_column_with_hash_bucket(
      'occupation', hash_bucket_size=_HASH_BUCKET_SIZE)

  # Transformations.
  age_buckets = tf.feature_column.bucketized_column(
      age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

  # Wide columns and deep columns.
  base_columns = [
      education, marital_status, relationship, workclass, occupation,
      age_buckets,
  ]

  crossed_columns = [
      tf.feature_column.crossed_column(
          ['education', 'occupation'], hash_bucket_size=_HASH_BUCKET_SIZE),
      tf.feature_column.crossed_column(
          [age_buckets, 'education', 'occupation'],
          hash_bucket_size=_HASH_BUCKET_SIZE),
  ]

  wide_columns = base_columns + crossed_columns

  deep_columns = [
      age,
      education_num,
      capital_gain,
      capital_loss,
      hours_per_week,
      tf.feature_column.indicator_column(workclass),
      tf.feature_column.indicator_column(education),
      tf.feature_column.indicator_column(marital_status),
      tf.feature_column.indicator_column(relationship),
      # To show an example of embedding
      tf.feature_column.embedding_column(occupation, dimension=8),
  ]

  return wide_columns, deep_columns


def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have run census_dataset.py and '
      'set the --data_dir argument to the correct path.' % data_file)

  def parse_csv(value):
    tf.logging.info('Parsing {}'.format(data_file))
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    classes = tf.equal(labels, '>50K')  # binary classification
    return features, classes

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset


def define_data_download_flags():
  """Add flags specifying data download arguments."""
  flags.DEFINE_string(
      name="data_dir", default="/tmp/census_data/",
      help=flags_core.help_wrap(
          "Directory to download and extract data."))


def main(_):
  download(flags.FLAGS.data_dir)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  define_data_download_flags()
  absl_app.run(main)