Commit 3b15ca78 authored by Younghee Kwon, committed by Taylor Robie

Adds boosted_trees to the official models (#4074)

* Add boosted_trees to the official models

* Comments addressed from review, and a test added; using absl.flags instead of argparse.

* Used help_wrap. Also added instructions for inference.
parent 499071ef
@@ -12,6 +12,8 @@ If you are on a version of TensorFlow earlier than 1.4, please [update your inst
Below is a list of the models available.
[boosted_trees](boosted_trees): A Gradient Boosted Trees model to classify Higgs boson processes from the HIGGS Data Set.
[mnist](mnist): A basic model to classify digits from the MNIST dataset.
[resnet](resnet): A deep residual network that can be used to classify both CIFAR-10 and ImageNet's dataset of 1000 classes.
# Classifying Higgs boson processes in the HIGGS Data Set
## Overview
The [HIGGS Data Set](https://archive.ics.uci.edu/ml/datasets/HIGGS) contains 11 million samples with 28 features each. It poses a binary classification problem: distinguishing a signal process that produces Higgs bosons from a background process that does not.
We use the Gradient Boosted Trees algorithm to distinguish the two classes.
---
The code sample uses the high-level `tf.estimator.Estimator` and `tf.data.Dataset` APIs. These APIs are great for fast iteration and quickly adapting models to your own datasets without major code overhauls. They let you move from single-worker training to distributed training, and make it easy to export model binaries for prediction. Here, for further simplicity and faster execution, we use the utility function `tf.contrib.estimator.boosted_trees_classifier_train_in_memory`. This utility is especially effective when the input is provided as an in-memory dataset such as numpy arrays.
An input function for an `Estimator` typically uses the `tf.data.Dataset` API, which handles streaming, batching, transformation, and shuffling. However, the `boosted_trees_classifier_train_in_memory()` utility requires that the entire data be provided as a single batch (i.e. without using the `batch()` API). So here `Dataset.from_tensors()` is used to convert numpy arrays into structured tensors, and `Dataset.zip()` is used to combine features and labels. A minimal sketch of this pattern is shown below.
For further reference on `Dataset`, [read more here](https://www.tensorflow.org/programmers_guide/datasets).
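The following sketch shows the single-batch input function pattern described above (the names `features_np`, `label_np`, and `make_single_batch_input_fn` are illustrative; the actual code in `train_higgs.py` additionally splits the features into named columns):
```
import numpy as np
import tensorflow as tf

def make_single_batch_input_fn(features_np, label_np):
  """Returns an input_fn that emits the whole dataset as a single element."""
  def input_fn():
    features = {'all_features': tf.constant(features_np)}
    # No .batch() call: from_tensors() wraps each array as one tensor.
    return tf.data.Dataset.zip((tf.data.Dataset.from_tensors(features),
                                tf.data.Dataset.from_tensors(label_np)))
  return input_fn

# Example usage with toy arrays:
input_fn = make_single_batch_input_fn(
    np.random.rand(8, 28).astype(np.float32),
    np.random.randint(2, size=(8, 1)).astype(np.float32))
```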
## Running the code
First make sure you've [added the models folder to your Python path](/official/#running-the-models); otherwise you may encounter an error like `ImportError: No module named official.boosted_trees`.
### Setup
The [HIGGS Data Set](https://archive.ics.uci.edu/ml/datasets/HIGGS) that this sample uses for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). We have provided a script that downloads and cleans the necessary files.
```
python data_download.py
```
This will download the raw file and store the processed file under the directory designated by `--data_dir` (defaults to `/tmp/higgs_data/`). The directory can also be a network storage location that TensorFlow supports (like Google Cloud Storage, `gs://<bucket>/<path>/`).
The file downloaded to the local temporary folder is about 2.8 GB, and the processed file is about 0.8 GB, so make sure there is enough storage to handle them.
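For example, to store the processed file in a Cloud Storage bucket instead of the local default (`<bucket>` is a placeholder for your own bucket name):
```
python data_download.py --data_dir=gs://<bucket>/higgs_data
```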
### Training
This example uses about 3 GB of RAM during training.
You can run the code locally as follows:
```
python train_higgs.py
```
The model is saved to `/tmp/higgs_model` by default, which can be changed using the `--model_dir` flag.
Note that `model_dir` is cleaned up before each training run starts.
Model parameters can be adjusted via flags such as `--n_trees`, `--max_depth`, and `--learning_rate`. Check out the code for details.
When trained with the default parameters, the final accuracy will be around 74% and the loss around 0.516 over the eval set.
By default, the first 1 million of the 11 million examples are used for training, and the last 1 million examples are used for evaluation.
The training/evaluation data can be selected as index ranges via the `--train_start`, `--train_count`, `--eval_start`, and `--eval_count` flags, as in the example below.
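For example, to train 100 trees of depth 6 over the first 10 million examples instead of 1 million:
```
python train_higgs.py --n_trees=100 --max_depth=6 --learning_rate=0.1 \
  --train_count=10000000
```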
### TensorBoard
Run TensorBoard to inspect the details about the graph and training progression.
```
tensorboard --logdir=/tmp/higgs_model  # Set --logdir to the --model_dir used during training.
```
## Inference with SavedModel
You can export the model into the TensorFlow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format by using the `--export_dir` flag:
```
python train_higgs.py --export_dir /tmp/higgs_boosted_trees_saved_model
```
After the model finishes training, use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
Try the following commands to inspect the SavedModel:
**Replace `${TIMESTAMP}` with the folder produced (e.g. 1524249124)**
```
# List possible tag_sets. Only one metagraph is saved, so there will be one option.
saved_model_cli show --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/
# Show SignatureDefs for tag_set=serve. SignatureDefs define the outputs to show.
saved_model_cli show --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/ \
--tag_set serve --all
```
### Inference
Let's use the model to predict the classes of two examples:
```
saved_model_cli run --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/ \
--tag_set serve --signature_def="predict" \
--input_examples='examples=[{"feature_01":[0.8692932],"feature_02":[-0.6350818],"feature_03":[0.2256903],"feature_04":[0.3274701],"feature_05":[-0.6899932],"feature_06":[0.7542022],"feature_07":[-0.2485731],"feature_08":[-1.0920639],"feature_09":[0.0],"feature_10":[1.3749921],"feature_11":[-0.6536742],"feature_12":[0.9303491],"feature_13":[1.1074361],"feature_14":[1.1389043],"feature_15":[-1.5781983],"feature_16":[-1.0469854],"feature_17":[0.0],"feature_18":[0.6579295],"feature_19":[-0.0104546],"feature_20":[-0.0457672],"feature_21":[3.1019614],"feature_22":[1.3537600],"feature_23":[0.9795631],"feature_24":[0.9780762],"feature_25":[0.9200048],"feature_26":[0.7216575],"feature_27":[0.9887509],"feature_28":[0.8766783]}, {"feature_01":[1.5958393],"feature_02":[-0.6078107],"feature_03":[0.0070749],"feature_04":[1.8184496],"feature_05":[-0.1119060],"feature_06":[0.8475499],"feature_07":[-0.5664370],"feature_08":[1.5812393],"feature_09":[2.1730762],"feature_10":[0.7554210],"feature_11":[0.6431096],"feature_12":[1.4263668],"feature_13":[0.0],"feature_14":[0.9216608],"feature_15":[-1.1904324],"feature_16":[-1.6155890],"feature_17":[0.0],"feature_18":[0.6511141],"feature_19":[-0.6542270],"feature_20":[-1.2743449],"feature_21":[3.1019614],"feature_22":[0.8237606],"feature_23":[0.9381914],"feature_24":[0.9717582],"feature_25":[0.7891763],"feature_26":[0.4305533],"feature_27":[0.9613569],"feature_28":[0.9578179]}]'
```
This will print out the predicted classes and class probabilities.
## Additional Links
If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).
You can also [train models on Cloud ML Engine](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction), which provides [hyperparameter tuning](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#hyperparameter_tuning) to maximize your model's results and enables [deploying your model for prediction](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#deploy_a_model_to_support_prediction).
"""Downloads the UCI HIGGS Dataset and prepares train data.
The details on the dataset are in https://archive.ics.uci.edu/ml/datasets/HIGGS
It takes a while, as it needs to download 2.8 GB over the network, process the
data, then store it into the specified location as a compressed numpy file.
Usage:
$ python data_download.py --data_dir=/tmp/higgs_data
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tempfile
import numpy as np
import pandas as pd
from six.moves import urllib
import tensorflow as tf
URL_ROOT = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00280'
INPUT_FILE = 'HIGGS.csv.gz'
NPZ_FILE = 'HIGGS.csv.gz.npz' # numpy compressed file to contain 'data' array.
def parse_args():
"""Parses arguments and returns a tuple (known_args, unparsed_args)."""
parser = argparse.ArgumentParser()
parser.add_argument(
'--data_dir', type=str, default='/tmp/higgs_data',
help='Directory to download higgs dataset and store training/eval data.')
return parser.parse_known_args()
def _download_higgs_data_and_save_npz(data_dir):
"""Download higgs data and store as a numpy compressed file."""
input_url = os.path.join(URL_ROOT, INPUT_FILE)
np_filename = os.path.join(data_dir, NPZ_FILE)
if tf.gfile.Exists(np_filename):
raise ValueError('data_dir already has the processed data file: {}'.format(
np_filename))
if not tf.gfile.Exists(data_dir):
tf.gfile.MkDir(data_dir)
# 2.8 GB to download.
  temp_filename = None
  try:
    print('Data downloading..')
    temp_filename, _ = urllib.request.urlretrieve(input_url)
    # Reading and parsing 11 million csv lines takes 2~3 minutes.
    print('Data processing.. taking multiple minutes..')
    data = pd.read_csv(
        temp_filename,
        dtype=np.float32,
        names=['c%02d' % i for i in range(29)]  # label + 28 features.
    ).as_matrix()
  finally:
    if temp_filename:
      os.remove(temp_filename)  # Skipped if the download never completed.
  # Write to a temporary location, then copy to the data_dir (0.8 GB).
f = tempfile.NamedTemporaryFile()
np.savez_compressed(f, data=data)
tf.gfile.Copy(f.name, np_filename)
print('Data saved to: {}'.format(np_filename))
def main(unused_argv):
if not tf.gfile.Exists(FLAGS.data_dir):
tf.gfile.MkDir(FLAGS.data_dir)
_download_higgs_data_and_save_npz(FLAGS.data_dir)
if __name__ == '__main__':
FLAGS, unparsed = parse_args()
tf.app.run(argv=[sys.argv[0]] + unparsed)
r"""A script that builds boosted trees over higgs data.
If you haven't, please run data_download.py beforehand to prepare the data.
For some more details on this example, please refer to README.md as well.
Note that the model_dir is cleaned up before starting the training.
Usage:
$ python train_higgs.py --n_trees=100 --max_depth=6 --learning_rate=0.1 \
--model_dir=/tmp/higgs_model
Note that BoostedTreesClassifier is available since TensorFlow 1.8.0, so you
need to install a recent enough version of TensorFlow to use this example.
The training data is by default the first million examples out of 11M examples,
and eval data is by default the last million examples.
They are controlled by --train_start, --train_count, --eval_start, --eval_count.
e.g. to train over the first 10 million examples instead of 1 million:
$ python train_higgs.py --n_trees=100 --max_depth=6 --learning_rate=0.1 \
--model_dir=/tmp/higgs_model --train_count=10000000
Training history and metrics can be inspected using TensorBoard. Set --logdir
to the --model_dir used during training (or the default /tmp/higgs_model).
$ tensorboard --logdir=/tmp/higgs_model
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app as absl_app
from absl import flags
import numpy as np # pylint: disable=wrong-import-order
import tensorflow as tf # pylint: disable=wrong-import-order
from official.utils.flags import core as flags_core
from official.utils.flags._conventions import help_wrap
NPZ_FILE = 'HIGGS.csv.gz.npz' # numpy compressed file containing 'data' array
def define_train_higgs_flags():
"""Add tree related flags as well as training/eval configuration."""
flags_core.define_base(stop_threshold=False, batch_size=False, num_gpu=False)
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_integer(
name='train_start', default=0,
help=help_wrap('Start index of train examples within the data.'))
flags.DEFINE_integer(
name='train_count', default=1000000,
help=help_wrap('Number of train examples within the data.'))
flags.DEFINE_integer(
name='eval_start', default=10000000,
help=help_wrap('Start index of eval examples within the data.'))
flags.DEFINE_integer(
name='eval_count', default=1000000,
help=help_wrap('Number of eval examples within the data.'))
flags.DEFINE_integer(
'n_trees', default=100, help=help_wrap('Number of trees to build.'))
flags.DEFINE_integer(
      'max_depth', default=6, help=help_wrap('Maximum depth of each tree.'))
  flags.DEFINE_float(
      'learning_rate', default=0.1,
      help=help_wrap('The learning rate.'))
flags_core.set_defaults(data_dir='/tmp/higgs_data',
model_dir='/tmp/higgs_model')
def read_higgs_data(data_dir, train_start, train_count, eval_start, eval_count):
"""Reads higgs data from csv and returns train and eval data."""
npz_filename = os.path.join(data_dir, NPZ_FILE)
try:
# gfile allows numpy to read data from network data sources as well.
with tf.gfile.Open(npz_filename, 'rb') as npz_file:
with np.load(npz_file) as npz:
data = npz['data']
except Exception as e:
raise RuntimeError(
'Error loading data; use data_download.py to prepare the data:\n{}: {}'
.format(type(e).__name__, e))
return (data[train_start:train_start+train_count],
data[eval_start:eval_start+eval_count])
# This showcases how to make input_fn when the input data is available in the
# form of numpy arrays.
def make_inputs_from_np_arrays(features_np, label_np):
"""Makes and returns input_fn and feature_columns from numpy arrays.
  The generated input_fn will return a tf.data.Dataset of a feature dictionary
  and a label, and feature_columns will be a list of
  tf.feature_column.BucketizedColumn.
Note, for in-memory training, tf.data.Dataset should contain the whole data
as a single tensor. Don't use batch.
Args:
features_np: a numpy ndarray (shape=[batch_size, num_features]) for
float32 features.
label_np: a numpy ndarray (shape=[batch_size, 1]) for labels.
Returns:
input_fn: a function returning a Dataset of feature dict and label.
feature_column: a list of tf.feature_column.BucketizedColumn.
"""
num_features = features_np.shape[1]
features_np_list = np.split(features_np, num_features, axis=1)
# 1-based feature names.
feature_names = ['feature_%02d' % (i + 1) for i in range(num_features)]
# Create source feature_columns and bucketized_columns.
def get_bucket_boundaries(feature):
"""Returns bucket boundaries for feature by percentiles."""
return np.unique(np.percentile(feature, range(0, 100))).tolist()
source_columns = [
tf.feature_column.numeric_column(
feature_name, dtype=tf.float32,
# Although higgs data have no missing values, in general, default
# could be set as 0 or some reasonable value for missing values.
default_value=0.0)
for feature_name in feature_names
]
bucketized_columns = [
tf.feature_column.bucketized_column(
source_columns[i],
boundaries=get_bucket_boundaries(features_np_list[i]))
for i in range(num_features)
]
# Make an input_fn that extracts source features.
def input_fn():
"""Returns features as a dictionary of numpy arrays, and a label."""
features = {
feature_name: tf.constant(features_np_list[i])
for i, feature_name in enumerate(feature_names)
}
return tf.data.Dataset.zip((tf.data.Dataset.from_tensors(features),
tf.data.Dataset.from_tensors(label_np),))
return input_fn, bucketized_columns
def make_eval_inputs_from_np_arrays(features_np, label_np):
"""Makes eval input as streaming batches."""
num_features = features_np.shape[1]
features_np_list = np.split(features_np, num_features, axis=1)
# 1-based feature names.
feature_names = ['feature_%02d' % (i + 1) for i in range(num_features)]
def input_fn():
features = {
feature_name: tf.constant(features_np_list[i])
for i, feature_name in enumerate(feature_names)
}
return tf.data.Dataset.zip(
(tf.data.Dataset.from_tensor_slices(features),
tf.data.Dataset.from_tensor_slices(label_np),)).batch(1000)
return input_fn
def train_boosted_trees(flags_obj):
"""Train boosted_trees estimator on HIGGS data.
Args:
flags_obj: An object containing parsed flag values.
"""
# Clean up the model directory if present.
if tf.gfile.Exists(flags_obj.model_dir):
tf.gfile.DeleteRecursively(flags_obj.model_dir)
print('## data loading..')
train_data, eval_data = read_higgs_data(
flags_obj.data_dir, flags_obj.train_start, flags_obj.train_count,
flags_obj.eval_start, flags_obj.eval_count)
print('## data loaded; train: {}{}, eval: {}{}'.format(
train_data.dtype, train_data.shape, eval_data.dtype, eval_data.shape))
  # Data consists of one label column followed by 28 feature columns.
train_input_fn, feature_columns = make_inputs_from_np_arrays(
features_np=train_data[:, 1:], label_np=train_data[:, 0:1])
eval_input_fn = make_eval_inputs_from_np_arrays(
features_np=eval_data[:, 1:], label_np=eval_data[:, 0:1])
print('## features prepared. training starts..')
  # Though BoostedTreesClassifier is under tf.estimator, faster in-memory
  # training is currently only provided as a contrib utility.
classifier = tf.contrib.estimator.boosted_trees_classifier_train_in_memory(
train_input_fn,
feature_columns,
model_dir=flags_obj.model_dir or None,
n_trees=flags_obj.n_trees,
max_depth=flags_obj.max_depth,
learning_rate=flags_obj.learning_rate)
# Evaluation.
eval_result = classifier.evaluate(eval_input_fn)
# Exporting the savedmodel.
if flags_obj.export_dir is not None:
feature_spec = tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec(feature_columns))
classifier.export_savedmodel(flags_obj.export_dir, feature_spec)
def main(_):
train_boosted_trees(flags.FLAGS)
if __name__ == '__main__':
  # Training progress and eval results are shown as logging.INFO, so enable it.
tf.logging.set_verbosity(tf.logging.INFO)
define_train_higgs_flags()
absl_app.run(main)
1.000000000000000000e+00,8.692932128906250000e-01,-6.350818276405334473e-01,2.256902605295181274e-01,3.274700641632080078e-01,-6.899932026863098145e-01,7.542022466659545898e-01,-2.485731393098831177e-01,-1.092063903808593750e+00,0.000000000000000000e+00,1.374992132186889648e+00,-6.536741852760314941e-01,9.303491115570068359e-01,1.107436060905456543e+00,1.138904333114624023e+00,-1.578198313713073730e+00,-1.046985387802124023e+00,0.000000000000000000e+00,6.579295396804809570e-01,-1.045456994324922562e-02,-4.576716944575309753e-02,3.101961374282836914e+00,1.353760004043579102e+00,9.795631170272827148e-01,9.780761599540710449e-01,9.200048446655273438e-01,7.216574549674987793e-01,9.887509346008300781e-01,8.766783475875854492e-01
1.000000000000000000e+00,9.075421094894409180e-01,3.291472792625427246e-01,3.594118654727935791e-01,1.497969865798950195e+00,-3.130095303058624268e-01,1.095530629158020020e+00,-5.575249195098876953e-01,-1.588229775428771973e+00,2.173076152801513672e+00,8.125811815261840820e-01,-2.136419266462326050e-01,1.271014571189880371e+00,2.214872121810913086e+00,4.999939501285552979e-01,-1.261431813240051270e+00,7.321561574935913086e-01,0.000000000000000000e+00,3.987008929252624512e-01,-1.138930082321166992e+00,-8.191101951524615288e-04,0.000000000000000000e+00,3.022198975086212158e-01,8.330481648445129395e-01,9.856996536254882812e-01,9.780983924865722656e-01,7.797321677207946777e-01,9.923557639122009277e-01,7.983425855636596680e-01
1.000000000000000000e+00,7.988347411155700684e-01,1.470638751983642578e+00,-1.635974764823913574e+00,4.537731707096099854e-01,4.256291687488555908e-01,1.104874610900878906e+00,1.282322287559509277e+00,1.381664276123046875e+00,0.000000000000000000e+00,8.517372012138366699e-01,1.540658950805664062e+00,-8.196895122528076172e-01,2.214872121810913086e+00,9.934899210929870605e-01,3.560801148414611816e-01,-2.087775468826293945e-01,2.548224449157714844e+00,1.256954550743103027e+00,1.128847599029541016e+00,9.004608392715454102e-01,0.000000000000000000e+00,9.097532629966735840e-01,1.108330488204956055e+00,9.856922030448913574e-01,9.513312578201293945e-01,8.032515048980712891e-01,8.659244179725646973e-01,7.801175713539123535e-01
0.000000000000000000e+00,1.344384789466857910e+00,-8.766260147094726562e-01,9.359127283096313477e-01,1.992050051689147949e+00,8.824543952941894531e-01,1.786065936088562012e+00,-1.646777749061584473e+00,-9.423825144767761230e-01,0.000000000000000000e+00,2.423264741897583008e+00,-6.760157942771911621e-01,7.361586689949035645e-01,2.214872121810913086e+00,1.298719763755798340e+00,-1.430738091468811035e+00,-3.646581768989562988e-01,0.000000000000000000e+00,7.453126907348632812e-01,-6.783788204193115234e-01,-1.360356330871582031e+00,0.000000000000000000e+00,9.466524720191955566e-01,1.028703689575195312e+00,9.986560940742492676e-01,7.282806038856506348e-01,8.692002296447753906e-01,1.026736497879028320e+00,9.579039812088012695e-01
1.000000000000000000e+00,1.105008959770202637e+00,3.213555514812469482e-01,1.522401213645935059e+00,8.828076124191284180e-01,-1.205349326133728027e+00,6.814661026000976562e-01,-1.070463895797729492e+00,-9.218706488609313965e-01,0.000000000000000000e+00,8.008721470832824707e-01,1.020974040031433105e+00,9.714065194129943848e-01,2.214872121810913086e+00,5.967612862586975098e-01,-3.502728641033172607e-01,6.311942934989929199e-01,0.000000000000000000e+00,4.799988865852355957e-01,-3.735655248165130615e-01,1.130406111478805542e-01,0.000000000000000000e+00,7.558564543724060059e-01,1.361057043075561523e+00,9.866096973419189453e-01,8.380846381187438965e-01,1.133295178413391113e+00,8.722448945045471191e-01,8.084865212440490723e-01
0.000000000000000000e+00,1.595839262008666992e+00,-6.078106760978698730e-01,7.074915803968906403e-03,1.818449616432189941e+00,-1.119059920310974121e-01,8.475499153137207031e-01,-5.664370059967041016e-01,1.581239342689514160e+00,2.173076152801513672e+00,7.554209828376770020e-01,6.431096196174621582e-01,1.426366806030273438e+00,0.000000000000000000e+00,9.216607809066772461e-01,-1.190432429313659668e+00,-1.615589022636413574e+00,0.000000000000000000e+00,6.511141061782836914e-01,-6.542269587516784668e-01,-1.274344921112060547e+00,3.101961374282836914e+00,8.237605690956115723e-01,9.381914138793945312e-01,9.717581868171691895e-01,7.891763448715209961e-01,4.305532872676849365e-01,9.613569378852844238e-01,9.578179121017456055e-01
1.000000000000000000e+00,4.093913435935974121e-01,-1.884683609008789062e+00,-1.027292013168334961e+00,1.672451734542846680e+00,-1.604598283767700195e+00,1.338014960289001465e+00,5.542744323611259460e-02,1.346588134765625000e-02,2.173076152801513672e+00,5.097832679748535156e-01,-1.038338065147399902e+00,7.078623175621032715e-01,0.000000000000000000e+00,7.469175457954406738e-01,-3.584651052951812744e-01,-1.646654248237609863e+00,0.000000000000000000e+00,3.670579791069030762e-01,6.949646025896072388e-02,1.377130270004272461e+00,3.101961374282836914e+00,8.694183826446533203e-01,1.222082972526550293e+00,1.000627398490905762e+00,5.450449585914611816e-01,6.986525058746337891e-01,9.773144721984863281e-01,8.287860751152038574e-01
1.000000000000000000e+00,9.338953495025634766e-01,6.291297078132629395e-01,5.275348424911499023e-01,2.380327433347702026e-01,-9.665691256523132324e-01,5.478111505508422852e-01,-5.943922698497772217e-02,-1.706866145133972168e+00,2.173076152801513672e+00,9.410027265548706055e-01,-2.653732776641845703e+00,-1.572199910879135132e-01,0.000000000000000000e+00,1.030370354652404785e+00,-1.755051016807556152e-01,5.230209231376647949e-01,2.548224449157714844e+00,1.373546600341796875e+00,1.291248083114624023e+00,-1.467454433441162109e+00,0.000000000000000000e+00,9.018372893333435059e-01,1.083671212196350098e+00,9.796960949897766113e-01,7.833003997802734375e-01,8.491951823234558105e-01,8.943563103675842285e-01,7.748793959617614746e-01
1.000000000000000000e+00,1.405143737792968750e+00,5.366026163101196289e-01,6.895543336868286133e-01,1.179567337036132812e+00,-1.100611537694931030e-01,3.202404975891113281e+00,-1.526960015296936035e+00,-1.576033473014831543e+00,0.000000000000000000e+00,2.931536912918090820e+00,5.673424601554870605e-01,-1.300333440303802490e-01,2.214872121810913086e+00,1.787122726440429688e+00,8.994985818862915039e-01,5.851513147354125977e-01,2.548224449157714844e+00,4.018652141094207764e-01,-1.512016952037811279e-01,1.163489103317260742e+00,0.000000000000000000e+00,1.667070508003234863e+00,4.039272785186767578e+00,1.175828456878662109e+00,1.045351743698120117e+00,1.542971968650817871e+00,3.534826755523681641e+00,2.740753889083862305e+00
1.000000000000000000e+00,1.176565527915954590e+00,1.041605025529861450e-01,1.397002458572387695e+00,4.797213077545166016e-01,2.655133903026580811e-01,1.135563015937805176e+00,1.534830927848815918e+00,-2.532912194728851318e-01,0.000000000000000000e+00,1.027246594429016113e+00,5.343157649040222168e-01,1.180022358894348145e+00,0.000000000000000000e+00,2.405661106109619141e+00,8.755676448345184326e-02,-9.765340685844421387e-01,2.548224449157714844e+00,1.250382542610168457e+00,2.685412168502807617e-01,5.303344726562500000e-01,0.000000000000000000e+00,8.331748843193054199e-01,7.739681005477905273e-01,9.857499599456787109e-01,1.103696346282958984e+00,8.491398692131042480e-01,9.371039867401123047e-01,8.123638033866882324e-01
1.000000000000000000e+00,9.459739923477172852e-01,1.111244320869445801e+00,1.218337059020996094e+00,9.076390862464904785e-01,8.215369582176208496e-01,1.153243303298950195e+00,-3.654202818870544434e-01,-1.566054821014404297e+00,0.000000000000000000e+00,7.447192072868347168e-01,7.208195328712463379e-01,-3.758229315280914307e-01,2.214872121810913086e+00,6.088791489601135254e-01,3.078369498252868652e-01,-1.281638383865356445e+00,0.000000000000000000e+00,1.597967982292175293e+00,-4.510180354118347168e-01,6.365344673395156860e-02,3.101961374282836914e+00,8.290241360664367676e-01,9.806482791900634766e-01,9.943597912788391113e-01,9.082478284835815430e-01,7.758789062500000000e-01,7.833113670349121094e-01,7.251217961311340332e-01
0.000000000000000000e+00,7.393567562103271484e-01,-1.782904267311096191e-01,8.299342393875122070e-01,5.045390725135803223e-01,-1.302167475223541260e-01,9.610513448715209961e-01,-3.555179834365844727e-01,-1.717399358749389648e+00,2.173076152801513672e+00,6.209560632705688477e-01,-4.817410409450531006e-01,-1.199193239212036133e+00,0.000000000000000000e+00,9.826014041900634766e-01,8.118502795696258545e-02,-2.903236448764801025e-01,0.000000000000000000e+00,1.064662933349609375e+00,7.740649580955505371e-01,3.988203406333923340e-01,3.101961374282836914e+00,9.445360302925109863e-01,1.026260614395141602e+00,9.821967482566833496e-01,5.421146750450134277e-01,1.250978946685791016e+00,8.300446271896362305e-01,7.613079547882080078e-01
1.000000000000000000e+00,1.384097695350646973e+00,1.168220937252044678e-01,-1.179878950119018555e+00,7.629125714302062988e-01,-7.978226989507675171e-02,1.019863128662109375e+00,8.773182630538940430e-01,1.276887178421020508e+00,2.173076152801513672e+00,3.312520980834960938e-01,1.409523487091064453e+00,-1.474388837814331055e+00,0.000000000000000000e+00,1.282738208770751953e+00,7.374743819236755371e-01,-2.254196107387542725e-01,0.000000000000000000e+00,1.559753060340881348e+00,8.465205430984497070e-01,5.048085451126098633e-01,3.101961374282836914e+00,9.593246579170227051e-01,8.073760271072387695e-01,1.191813588142395020e+00,1.221210360527038574e+00,8.611412644386291504e-01,9.293408989906311035e-01,8.383023738861083984e-01
1.000000000000000000e+00,1.383548736572265625e+00,8.891792893409729004e-01,6.185320615768432617e-01,1.081547021865844727e+00,3.446055650711059570e-01,9.563793540000915527e-01,8.545429706573486328e-01,-1.129207015037536621e+00,2.173076152801513672e+00,5.456657409667968750e-01,-3.078651726245880127e-01,-6.232798099517822266e-01,2.214872121810913086e+00,3.482571244239807129e-01,1.024202585220336914e+00,1.840776652097702026e-01,0.000000000000000000e+00,7.813369035720825195e-01,-1.636125564575195312e+00,1.144067287445068359e+00,0.000000000000000000e+00,5.222384929656982422e-01,7.376385331153869629e-01,9.861995577812194824e-01,1.349615693092346191e+00,8.127878904342651367e-01,9.534064531326293945e-01,7.797226309776306152e-01
1.000000000000000000e+00,1.343652725219726562e+00,8.385329246520996094e-01,-1.061138510704040527e+00,2.472015142440795898e+00,-5.726317167282104492e-01,1.512709975242614746e+00,1.143690109252929688e+00,8.555619716644287109e-01,0.000000000000000000e+00,8.842203021049499512e-01,1.474605560302734375e+00,-1.360648751258850098e+00,1.107436060905456543e+00,1.587265610694885254e+00,2.234833478927612305e+00,7.756848633289337158e-02,0.000000000000000000e+00,1.609408140182495117e+00,2.396404743194580078e+00,7.572935223579406738e-01,0.000000000000000000e+00,9.340201020240783691e-01,8.447072505950927734e-01,1.077844023704528809e+00,1.400183677673339844e+00,9.477745294570922852e-01,1.007614254951477051e+00,9.010174870491027832e-01
0.000000000000000000e+00,5.470141768455505371e-01,-3.497089445590972900e-01,-6.466571688652038574e-01,2.040462255477905273e+00,2.764569818973541260e-01,5.446965098381042480e-01,8.386992812156677246e-01,1.728703141212463379e+00,0.000000000000000000e+00,6.528096199035644531e-01,1.471691370010375977e+00,1.243273019790649414e+00,0.000000000000000000e+00,7.857298851013183594e-01,-4.442929103970527649e-02,-1.019803404808044434e+00,2.548224449157714844e+00,4.191471040248870850e-01,-6.292421817779541016e-01,1.570794582366943359e+00,3.101961374282836914e+00,6.894335746765136719e-01,8.672295808792114258e-01,1.082487821578979492e+00,6.641419529914855957e-01,3.541145622730255127e-01,5.799450278282165527e-01,8.172734379768371582e-01
1.000000000000000000e+00,1.484203696250915527e+00,1.699521422386169434e+00,-1.059473991394042969e+00,2.700195550918579102e+00,-1.055963873863220215e+00,2.409452915191650391e+00,4.574607908725738525e-01,3.449823260307312012e-01,0.000000000000000000e+00,1.414903521537780762e+00,1.114225864410400391e+00,-1.448866605758666992e+00,0.000000000000000000e+00,1.012983918190002441e+00,-2.056988954544067383e+00,1.131010890007019043e+00,0.000000000000000000e+00,9.054746031761169434e-01,2.182368993759155273e+00,1.043073177337646484e+00,0.000000000000000000e+00,1.653626322746276855e+00,9.935762286186218262e-01,9.833217859268188477e-01,7.413797974586486816e-01,1.633816361427307129e-01,5.923243165016174316e-01,7.451378703117370605e-01
0.000000000000000000e+00,1.057975649833679199e+00,-1.607590019702911377e-01,-1.949972510337829590e-01,2.705023050308227539e+00,-7.514767050743103027e-01,1.909918904304504395e+00,-1.031844973564147949e+00,8.649863600730895996e-01,0.000000000000000000e+00,1.300834894180297852e+00,1.467376798391342163e-01,-1.118742942810058594e+00,1.107436060905456543e+00,9.669710993766784668e-01,-3.666573464870452881e-01,1.108266711235046387e+00,0.000000000000000000e+00,5.547249317169189453e-01,-7.141901850700378418e-01,1.505314946174621582e+00,3.101961374282836914e+00,9.544943571090698242e-01,6.510385870933532715e-01,1.124949693679809570e+00,8.940010070800781250e-01,6.721734404563903809e-01,1.182358264923095703e+00,1.316304087638854980e+00
0.000000000000000000e+00,6.753035783767700195e-01,1.120983958244323730e+00,-2.804459035396575928e-01,1.539554953575134277e+00,7.345175743103027344e-01,6.146844029426574707e-01,-5.070231556892395020e-01,7.945806980133056641e-01,2.173076152801513672e+00,2.188202738761901855e-01,-1.894118309020996094e+00,-5.805578827857971191e-01,0.000000000000000000e+00,1.245682120323181152e+00,-3.475421071052551270e-01,-8.561564683914184570e-01,2.548224449157714844e+00,7.531017661094665527e-01,-1.145592689514160156e+00,-1.374783992767333984e+00,0.000000000000000000e+00,9.069401025772094727e-01,8.983390927314758301e-01,1.119651079177856445e+00,1.269073486328125000e+00,1.088765859603881836e+00,1.015413045883178711e+00,9.146358966827392578e-01
1.000000000000000000e+00,6.427279114723205566e-01,-1.429840326309204102e+00,1.519071936607360840e+00,9.409985542297363281e-01,8.872274160385131836e-01,1.615126848220825195e+00,-1.336835741996765137e+00,-2.665962278842926025e-01,1.086538076400756836e+00,1.667088270187377930e+00,6.557375192642211914e-01,-1.588128924369812012e+00,0.000000000000000000e+00,8.282302021980285645e-01,1.836144566535949707e+00,4.081907570362091064e-01,0.000000000000000000e+00,1.708718180656433105e+00,-3.469151556491851807e-01,-1.182784557342529297e+00,3.101961374282836914e+00,9.210902452468872070e-01,1.373361706733703613e+00,9.849172830581665039e-01,1.422878146171569824e+00,1.546551108360290527e+00,1.782585501670837402e+00,1.438173770904541016e+00
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import numpy as np
import pandas as pd
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.testing import integration
from official.boosted_trees import train_higgs
tf.logging.set_verbosity(tf.logging.ERROR)
TEST_CSV = os.path.join(os.path.dirname(__file__), 'train_higgs_test.csv')
class BaseTest(tf.test.TestCase):
"""Tests for Wide Deep model."""
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass()
train_higgs.define_train_higgs_flags()
def setUp(self):
    # Create a temporary npz file from the test CSV data.
self.data_dir = self.get_temp_dir()
data = pd.read_csv(
TEST_CSV, dtype=np.float32, names=['c%02d' % i for i in range(29)]
).as_matrix()
self.input_npz = os.path.join(self.data_dir, train_higgs.NPZ_FILE)
    # numpy.savez doesn't accept a gfile.GFile, so write locally, then copy.
tmpfile = tempfile.NamedTemporaryFile()
np.savez_compressed(tmpfile, data=data)
tf.gfile.Copy(tmpfile.name, self.input_npz)
def test_read_higgs_data(self):
"""Tests read_higgs_data() function."""
# Error when a wrong data_dir is given.
with self.assertRaisesRegexp(RuntimeError, 'Error loading data.*'):
      train_higgs.read_higgs_data(
          self.data_dir + 'non-existing-path',
          train_start=0, train_count=15, eval_start=15, eval_count=5)
# Loading fine with the correct data_dir.
train_data, eval_data = train_higgs.read_higgs_data(
self.data_dir,
train_start=0, train_count=15, eval_start=15, eval_count=5)
self.assertEqual((15, 29), train_data.shape)
self.assertEqual((5, 29), eval_data.shape)
def test_make_inputs_from_np_arrays(self):
"""Tests make_inputs_from_np_arrays() function."""
train_data, _ = train_higgs.read_higgs_data(
self.data_dir,
train_start=0, train_count=15, eval_start=15, eval_count=5)
input_fn, feature_columns = train_higgs.make_inputs_from_np_arrays(
features_np=train_data[:, 1:], label_np=train_data[:, 0:1])
# Check feature columns.
self.assertEqual(28, len(feature_columns))
bucketized_column_type = type(
tf.feature_column.bucketized_column(
tf.feature_column.numeric_column('feature_01'),
boundaries=[0, 1, 2])) # dummy boundaries.
for feature_column in feature_columns:
self.assertIsInstance(feature_column, bucketized_column_type)
# At least 2 boundaries.
self.assertGreaterEqual(len(feature_column.boundaries), 2)
feature_names = ['feature_%02d' % (i+1) for i in range(28)]
# Tests that the source column names of the bucketized columns match.
self.assertAllEqual(feature_names,
[col.source_column.name for col in feature_columns])
# Check features.
features, labels = input_fn().make_one_shot_iterator().get_next()
with tf.Session() as sess:
features, labels = sess.run((features, labels))
self.assertIsInstance(features, dict)
self.assertAllEqual(feature_names, sorted(features.keys()))
self.assertAllEqual([[15, 1]] * 28,
[features[name].shape for name in feature_names])
# Validate actual values of some features.
self.assertAllClose(
[0.869293, 0.907542, 0.798834, 1.344384, 1.105009, 1.595839,
0.409391, 0.933895, 1.405143, 1.176565, 0.945974, 0.739356,
1.384097, 1.383548, 1.343652],
np.squeeze(features[feature_names[0]], 1))
self.assertAllClose(
[-0.653674, -0.213641, 1.540659, -0.676015, 1.020974, 0.643109,
-1.038338, -2.653732, 0.567342, 0.534315, 0.720819, -0.481741,
1.409523, -0.307865, 1.474605],
np.squeeze(features[feature_names[10]], 1))
def test_end_to_end(self):
"""Tests end-to-end running."""
model_dir = os.path.join(self.get_temp_dir(), 'model')
integration.run_synthetic(
main=train_higgs.main, tmp_root=self.get_temp_dir(), extra_flags=[
'--data_dir', self.data_dir,
'--model_dir', model_dir,
'--n_trees', '5',
'--train_start', '0',
'--train_count', '12',
'--eval_start', '12',
'--eval_count', '8',
],
synth=False, max_train=None)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, 'checkpoint')))
def test_end_to_end_with_export(self):
"""Tests end-to-end running."""
model_dir = os.path.join(self.get_temp_dir(), 'model')
export_dir = os.path.join(self.get_temp_dir(), 'export')
integration.run_synthetic(
main=train_higgs.main, tmp_root=self.get_temp_dir(), extra_flags=[
'--data_dir', self.data_dir,
'--model_dir', model_dir,
'--export_dir', export_dir,
'--n_trees', '5',
'--train_start', '0',
'--train_count', '12',
'--eval_start', '12',
'--eval_count', '8',
],
synth=False, max_train=None)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, 'checkpoint')))
self.assertTrue(tf.gfile.Exists(os.path.join(export_dir)))
if __name__ == '__main__':
tf.test.main()
numpy
pandas
psutil>=5.4.3
py-cpuinfo>=3.3.0
google-cloud-bigquery>=0.31.0