"docs/source/en/vscode:/vscode.git/clone" did not exist on "30c977d1f5300bffdf14b92d06c90293e789b587"
Commit 9dafea91 authored by sunxx1's avatar sunxx1
Browse files

Merge branch 'qianyj_tf' into 'main'

update tf code

See merge request dcutoolkit/deeplearing/dlexamples_new!35
parents 92a2ca36 a4146470
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for benchmark_uploader."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import tempfile
import unittest
from mock import MagicMock
from mock import patch
import tensorflow as tf # pylint: disable=g-bad-import-order
try:
from google.cloud import bigquery
from official.benchmark import benchmark_uploader
except ImportError:
bigquery = None
benchmark_uploader = None
@unittest.skipIf(bigquery is None, "Bigquery dependency is not installed.")
class BigQueryUploaderTest(tf.test.TestCase):
@patch.object(bigquery, "Client")
def setUp(self, mock_bigquery):
self.mock_client = mock_bigquery.return_value
self.mock_dataset = MagicMock(name="dataset")
self.mock_table = MagicMock(name="table")
self.mock_client.dataset.return_value = self.mock_dataset
self.mock_dataset.table.return_value = self.mock_table
self.mock_client.insert_rows_json.return_value = []
self.benchmark_uploader = benchmark_uploader.BigQueryUploader()
self.benchmark_uploader._bq_client = self.mock_client
self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
with open(os.path.join(self.log_dir, "metric.log"), "a") as f:
json.dump({"name": "accuracy", "value": 1.0}, f)
f.write("\n")
json.dump({"name": "loss", "value": 0.5}, f)
f.write("\n")
with open(os.path.join(self.log_dir, "run.log"), "w") as f:
json.dump({"model_name": "value"}, f)
def tearDown(self):
tf.gfile.DeleteRecursively(self.get_temp_dir())
def test_upload_benchmark_run_json(self):
self.benchmark_uploader.upload_benchmark_run_json(
"dataset", "table", "run_id", {"model_name": "value"})
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, [{"model_name": "value", "model_id": "run_id"}])
def test_upload_benchmark_metric_json(self):
metric_json_list = [
{"name": "accuracy", "value": 1.0},
{"name": "loss", "value": 0.5}
]
expected_params = [
{"run_id": "run_id", "name": "accuracy", "value": 1.0},
{"run_id": "run_id", "name": "loss", "value": 0.5}
]
self.benchmark_uploader.upload_benchmark_metric_json(
"dataset", "table", "run_id", metric_json_list)
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, expected_params)
def test_upload_benchmark_run_file(self):
self.benchmark_uploader.upload_benchmark_run_file(
"dataset", "table", "run_id", os.path.join(self.log_dir, "run.log"))
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, [{"model_name": "value", "model_id": "run_id"}])
def test_upload_metric_file(self):
self.benchmark_uploader.upload_metric_file(
"dataset", "table", "run_id",
os.path.join(self.log_dir, "metric.log"))
expected_params = [
{"run_id": "run_id", "name": "accuracy", "value": 1.0},
{"run_id": "run_id", "name": "loss", "value": 0.5}
]
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, expected_params)
def test_insert_run_status(self):
self.benchmark_uploader.insert_run_status(
"dataset", "table", "run_id", "status")
expected_query = ("INSERT dataset.table "
"(run_id, status) "
"VALUES('run_id', 'status')")
self.mock_client.query.assert_called_once_with(query=expected_query)
def test_update_run_status(self):
self.benchmark_uploader.update_run_status(
"dataset", "table", "run_id", "status")
expected_query = ("UPDATE dataset.table "
"SET status = 'status' "
"WHERE run_id = 'run_id'")
self.mock_client.query.assert_called_once_with(query=expected_query)
if __name__ == "__main__":
tf.test.main()
[
{
"description": "The ID of the benchmark run, where this metric should tie to.",
"mode": "REQUIRED",
"name": "run_id",
"type": "STRING"
},
{
"description": "The name of the metric, which should be descriptive. E.g. training_loss, accuracy.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The unit of the metric. E.g. MB per sec.",
"mode": "NULLABLE",
"name": "unit",
"type": "STRING"
},
{
"description": "The value of the metric.",
"mode": "NULLABLE",
"name": "value",
"type": "FLOAT"
},
{
"description": "The timestamp when the metric is recorded.",
"mode": "REQUIRED",
"name": "timestamp",
"type": "TIMESTAMP"
},
{
"description": "The global step when this metric is recorded.",
"mode": "NULLABLE",
"name": "global_step",
"type": "INTEGER"
},
{
"description": "Free format metadata for the extra information about the metric.",
"mode": "REPEATED",
"name": "extras",
"type": "RECORD",
"fields": [
{
"mode": "NULLABLE",
"name": "name",
"type": "STRING"
},
{
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
]
}
]
[
{
"description": "The UUID of the run for the benchmark.",
"mode": "REQUIRED",
"name": "model_id",
"type": "STRING"
},
{
"description": "The name of the model, E.g ResNet50, LeNet-5 etc.",
"mode": "REQUIRED",
"name": "model_name",
"type": "STRING"
},
{
"description": "The date when the test of the model is started",
"mode": "REQUIRED",
"name": "run_date",
"type": "TIMESTAMP"
},
{
"description": "The unique name for a test by the combination of key parameters, eg batch size, num of GPU, etc. It is hardware independent.",
"mode": "NULLABLE",
"name": "test_id",
"type": "STRING"
},
{
"description": "The tensorflow version information.",
"fields": [
{
"description": "Version of the tensorflow. E.g. 1.7.0-rc0",
"mode": "REQUIRED",
"name": "version",
"type": "STRING"
},
{
"description": "Git Hash of the tensorflow",
"mode": "NULLABLE",
"name": "git_hash",
"type": "STRING"
},
{
"description": "The channel of the tensorflow binary, eg, nightly, RC, final, custom.",
"mode": "NULLABLE",
"name": "channel",
"type": "STRING"
},
{
"description": "Identify anything special about the build, eg CUDA 10, NCCL, MKL, etc.",
"mode": "NULLABLE",
"name": "build_type",
"type": "STRING"
}
],
"mode": "REQUIRED",
"name": "tensorflow_version",
"type": "RECORD"
},
{
"description": "The arbitrary attribute of the model.",
"fields": [
{
"description": "The name of the attribute.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The value of the attribute.",
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
],
"mode": "REPEATED",
"name": "attribute",
"type": "RECORD"
},
{
"description": "Environment variables when the benchmark run is executed.",
"fields": [
{
"description": "The name of the variable.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The value of the variable.",
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
],
"mode": "REPEATED",
"name": "environment_variable",
"type": "RECORD"
},
{
"description": "TF Environment variables when the benchmark run is executed.",
"fields": [
{
"description": "The name of the variable.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The value of the variable.",
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
],
"mode": "REPEATED",
"name": "tensorflow_environment_variables",
"type": "RECORD"
},
{
"description": "The list of parameters run with the model. It could contain hyperparameters or others.",
"fields": [
{
"description": "The name of the parameter.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The string value of the parameter.",
"mode": "NULLABLE",
"name": "string_value",
"type": "STRING"
},
{
"description": "The bool value of the parameter.",
"mode": "NULLABLE",
"name": "bool_value",
"type": "STRING"
},
{
"description": "The int/long value of the parameter.",
"mode": "NULLABLE",
"name": "long_value",
"type": "INTEGER"
},
{
"description": "The double/float value of parameter.",
"mode": "NULLABLE",
"name": "float_value",
"type": "FLOAT"
}
],
"mode": "REPEATED",
"name": "run_parameters",
"type": "RECORD"
},
{
"description": "The dataset that run with the benchmark.",
"mode": "NULLABLE",
"name": "dataset",
"type": "RECORD",
"fields": [
{
"description": "The name of the dataset that the model is trained/validated with. E.g ImageNet, mnist.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The arbitrary attribute of the dataset.",
"fields": [
{
"description": "The name of the attribute.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The value of the attribute.",
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
],
"mode": "REPEATED",
"name": "attribute",
"type": "RECORD"
}
]
},
{
"description": "Used to differentiate from AWS, GCE or DGX-1 at a high level",
"mode": "NULLABLE",
"name": "test_environment",
"type": "STRING"
},
{
"description": "The machine configuration of the benchmark run.",
"mode": "NULLABLE",
"name": "machine_config",
"type": "RECORD",
"fields": [
{
"description": "The platform information of the benchmark run.",
"mode": "NULLABLE",
"name": "platform_info",
"type": "RECORD",
"fields": [
{
"description": "Eg: 64bit.",
"mode": "NULLABLE",
"name": "bits",
"type": "STRING"
},
{
"description": "Eg: ELF.",
"mode": "NULLABLE",
"name": "linkage",
"type": "STRING"
},
{
"description": "Eg: i386.",
"mode": "NULLABLE",
"name": "machine",
"type": "STRING"
},
{
"description": "Eg: 3.13.0-76-generic.",
"mode": "NULLABLE",
"name": "release",
"type": "STRING"
},
{
"description": "Eg: Linux.",
"mode": "NULLABLE",
"name": "system",
"type": "STRING"
},
{
"description": "Eg: #120-Ubuntu SMP Mon Jan 18 15:59:10 UTC 2016.",
"mode": "NULLABLE",
"name": "version",
"type": "STRING"
}
]
},
{
"description": "The CPU information of the benchmark run.",
"mode": "NULLABLE",
"name": "cpu_info",
"type": "RECORD",
"fields": [
{
"mode": "NULLABLE",
"name": "num_cores",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "num_cores_allowed",
"type": "INTEGER"
},
{
"description" : "How fast are those CPUs.",
"mode": "NULLABLE",
"name": "mhz_per_cpu",
"type": "FLOAT"
},
{
"description" : "Additional CPU info, Eg: Intel Ivybridge with HyperThreading (24 cores).",
"mode": "NULLABLE",
"name": "cpu_info",
"type": "STRING"
},
{
"description" : "What kind of cpu scaling is enabled on the host. Eg performance, ondemand, conservative, mixed.",
"mode": "NULLABLE",
"name": "cpu_governor",
"type": "STRING"
},
{
"description": "Cache size of the CPUs.",
"mode": "NULLABLE",
"name": "cache_size",
"type": "RECORD",
"fields": [
{
"mode": "NULLABLE",
"name": "level",
"type": "STRING"
},
{
"mode": "NULLABLE",
"name": "size",
"type": "INTEGER"
}
]
}
]
},
{
"mode": "NULLABLE",
"name": "gpu_info",
"type": "RECORD",
"fields": [
{
"mode": "NULLABLE",
"name": "count",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "model",
"type": "STRING"
},
{
"mode": "NULLABLE",
"name": "cuda_version",
"type": "STRING"
}
]
},
{
"description": "The cloud instance inforation if the benchmark run is executed on cloud",
"mode": "NULLABLE",
"name": "cloud_info",
"type": "RECORD",
"fields": [
{
"description": "The instance type, E.g. n1-standard-4.",
"mode": "NULLABLE",
"name": "instance_type",
"type": "STRING"
},
{
"description": "The arbitrary attribute of the cloud info.",
"fields": [
{
"description": "The name of the attribute.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The value of the attribute.",
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
],
"mode": "REPEATED",
"name": "attribute",
"type": "RECORD"
}
]
},
{
"mode": "NULLABLE",
"name": "memory_total",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "memory_available",
"type": "STRING"
}
]
}
]
[
{
"description": "The UUID of the run for the benchmark.",
"mode": "REQUIRED",
"name": "run_id",
"type": "STRING"
},
{
"description": "The status of the run for the benchmark. Eg, running, failed, success",
"mode": "REQUIRED",
"name": "status",
"type": "STRING"
}
]
# Classifying Higgs boson processes in the HIGGS Data Set
## Overview
The [HIGGS Data Set](https://archive.ics.uci.edu/ml/datasets/HIGGS) contains 11 million samples, each with 28 features, and poses a binary classification problem: distinguishing a signal process that produces Higgs bosons from a background process that does not.
We use the Gradient Boosted Trees algorithm to distinguish the two classes.
---
The code sample uses the high-level `tf.estimator.Estimator` and `tf.data.Dataset` APIs. These APIs are great for fast iteration and for quickly adapting models to your own datasets without major code overhauls. They allow you to move from single-worker training to distributed training, and make it easy to export model binaries for prediction. Here, for further simplicity and faster execution, we use the utility function `tf.contrib.estimator.boosted_trees_classifier_train_in_memory`, which is especially effective when the input is provided as an in-memory dataset such as numpy arrays.
An input function for the `Estimator` typically uses the `tf.data.Dataset` API, which can handle operations like streaming, batching, transformation, and shuffling. However, the `boosted_trees_classifier_train_in_memory()` utility function requires that the entire data be provided as a single batch (i.e. without using the `batch()` API). Thus, in this example, `Dataset.from_tensors()` is simply used to convert numpy arrays into structured tensors, and `Dataset.zip()` is used to put features and labels together (see the sketch below).
For further reference on `Dataset`, [read more here](https://www.tensorflow.org/guide/datasets).
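A minimal sketch of this pattern, with random placeholder arrays standing in for the real HIGGS features and labels (the real code builds one dictionary entry per feature column):
```
import numpy as np
import tensorflow as tf

# Placeholder arrays standing in for the HIGGS features and labels.
features_np = np.random.rand(100, 28).astype(np.float32)
labels_np = np.random.randint(0, 2, size=(100, 1)).astype(np.float32)

def input_fn():
  # from_tensors() wraps each array as a single dataset element (one batch),
  # as required by boosted_trees_classifier_train_in_memory().
  features = tf.data.Dataset.from_tensors({"features": features_np})
  labels = tf.data.Dataset.from_tensors(labels_np)
  return tf.data.Dataset.zip((features, labels))
```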
## Running the code
First make sure you've [added the models folder to your Python path](/official/#running-the-models); otherwise you may encounter an error like `ImportError: No module named official.boosted_trees`.
### Setup
The [HIGGS Data Set](https://archive.ics.uci.edu/ml/datasets/HIGGS) that this sample uses for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). We have provided a script that downloads and cleans the necessary files.
```
python data_download.py
```
This will download the raw file and store the processed file under the directory designated by `--data_dir` (which defaults to `/tmp/higgs_data/`). The directory can also be a network storage path that TensorFlow supports (like Google Cloud Storage, `gs://<bucket>/<path>/`).
The file downloaded to the local temporary folder is about 2.8 GB, and the processed file is about 0.8 GB, so make sure there is enough storage to handle them.
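For example, the processed data can be written directly to a Google Cloud Storage bucket (the bucket name below is a placeholder):
```
python data_download.py --data_dir=gs://my-bucket/higgs_data/
```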
### Training
This example uses about 3 GB of RAM during training.
You can run the code locally as follows:
```
python train_higgs.py
```
The model is by default saved to `/tmp/higgs_model`, which can be changed using the `--model_dir` flag.
Note that `model_dir` is cleaned up each time before training starts.
Model parameters can be adjusted with flags such as `--n_trees`, `--max_depth`, and `--learning_rate`; check out the code for details.
When trained with the default parameters, the final accuracy over the eval set will be around 74% and the loss will be around 0.516.
By default, the first 1 million of the 11 million examples are used for training, and the last 1 million are used for evaluation.
The training/evaluation data can be selected as index ranges with the flags `--train_start`, `--train_count`, `--eval_start`, and `--eval_count`.
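For instance, a hypothetical run that grows a larger ensemble and trains on the first 10 million examples (the flag values are illustrative only):
```
python train_higgs.py --n_trees=200 --max_depth=8 --learning_rate=0.05 \
    --train_start=0 --train_count=10000000
```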
### TensorBoard
Run TensorBoard to inspect the details about the graph and training progression.
```
tensorboard --logdir=/tmp/higgs_model  # set --logdir to the --model_dir used during training.
```
## Inference with SavedModel
You can export the model into the TensorFlow [SavedModel](https://www.tensorflow.org/guide/saved_model) format by using the `--export_dir` flag:
```
python train_higgs.py --export_dir /tmp/higgs_boosted_trees_saved_model
```
After the model finishes training, use [`saved_model_cli`](https://www.tensorflow.org/guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
Try the following commands to inspect the SavedModel:
**Replace `${TIMESTAMP}` with the folder produced (e.g. 1524249124)**
```
# List possible tag_sets. Only one metagraph is saved, so there will be one option.
saved_model_cli show --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/
# Show SignatureDefs for tag_set=serve. SignatureDefs define the outputs to show.
saved_model_cli show --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/ \
--tag_set serve --all
```
### Inference
Let's use the model to predict the classes of two examples.
Note that this model exports a SavedModel with a custom parsing module that accepts csv lines as features. (Each line is an example with 28 columns; be careful not to add a label column, unlike in the training data.)
```
saved_model_cli run --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/ \
--tag_set serve --signature_def="predict" \
--input_exprs='inputs=["0.869293,-0.635082,0.225690,0.327470,-0.689993,0.754202,-0.248573,-1.092064,0.0,1.374992,-0.653674,0.930349,1.107436,1.138904,-1.578198,-1.046985,0.0,0.657930,-0.010455,-0.045767,3.101961,1.353760,0.979563,0.978076,0.920005,0.721657,0.988751,0.876678", "1.595839,-0.607811,0.007075,1.818450,-0.111906,0.847550,-0.566437,1.581239,2.173076,0.755421,0.643110,1.426367,0.0,0.921661,-1.190432,-1.615589,0.0,0.651114,-0.654227,-1.274345,3.101961,0.823761,0.938191,0.971758,0.789176,0.430553,0.961357,0.957818"]'
```
This will print out the predicted classes and class probabilities, something like:
```
Result for output key class_ids:
[[1]
[0]]
Result for output key classes:
[['1']
['0']]
Result for output key logistic:
[[0.6440273 ]
[0.10902369]]
Result for output key logits:
[[ 0.59288704]
[-2.1007526 ]]
Result for output key probabilities:
[[0.3559727 0.6440273]
[0.8909763 0.1090237]]
```
Please note that "predict" signature_def gives out different (more detailed) results than "classification" or "serving_default".
## Additional Links
If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).
You can also [train models on Cloud ML Engine](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction), which provides [hyperparameter tuning](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#hyperparameter_tuning) to maximize your model's results and enables [deploying your model for prediction](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#deploy_a_model_to_support_prediction).
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Downloads the UCI HIGGS Dataset and prepares train data.
The details on the dataset are in https://archive.ics.uci.edu/ml/datasets/HIGGS
It takes a while as it needs to download 2.8 GB over the network, process the
data, and then store it in the specified location as a compressed numpy file.
Usage:
$ python data_download.py --data_dir=/tmp/higgs_data
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tempfile
# pylint: disable=g-bad-import-order
import numpy as np
import pandas as pd
from six.moves import urllib
from absl import app as absl_app
from absl import flags
import tensorflow as tf
from official.utils.flags import core as flags_core
URL_ROOT = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280"
INPUT_FILE = "HIGGS.csv.gz"
NPZ_FILE = "HIGGS.csv.gz.npz" # numpy compressed file to contain "data" array.
def _download_higgs_data_and_save_npz(data_dir):
"""Download higgs data and store as a numpy compressed file."""
input_url = URL_ROOT + "/" + INPUT_FILE
np_filename = os.path.join(data_dir, NPZ_FILE)
if tf.gfile.Exists(np_filename):
raise ValueError("data_dir already has the processed data file: {}".format(
np_filename))
if not tf.gfile.Exists(data_dir):
tf.gfile.MkDir(data_dir)
# 2.8 GB to download.
try:
tf.logging.info("Data downloading...")
temp_filename, _ = urllib.request.urlretrieve(input_url)
# Reading and parsing 11 million csv lines takes 2~3 minutes.
tf.logging.info("Data processing... taking multiple minutes...")
with gzip.open(temp_filename, "rb") as csv_file:
data = pd.read_csv(
csv_file,
dtype=np.float32,
names=["c%02d" % i for i in range(29)] # label + 28 features.
).as_matrix()
finally:
tf.gfile.Remove(temp_filename)
# Write to a temporary location, then copy to the data_dir (0.8 GB).
f = tempfile.NamedTemporaryFile()
np.savez_compressed(f, data=data)
tf.gfile.Copy(f.name, np_filename)
tf.logging.info("Data saved to: {}".format(np_filename))
def main(unused_argv):
if not tf.gfile.Exists(FLAGS.data_dir):
tf.gfile.MkDir(FLAGS.data_dir)
_download_higgs_data_and_save_npz(FLAGS.data_dir)
def define_data_download_flags():
"""Add flags specifying data download arguments."""
flags.DEFINE_string(
name="data_dir", default="/tmp/higgs_data",
help=flags_core.help_wrap(
"Directory to download higgs dataset and store training/eval data."))
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
define_data_download_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A script that builds boosted trees over higgs data.
If you haven't, please run data_download.py beforehand to prepare the data.
For some more details on this example, please refer to README.md as well.
Note that the model_dir is cleaned up before starting the training.
Usage:
$ python train_higgs.py --n_trees=100 --max_depth=6 --learning_rate=0.1 \
--model_dir=/tmp/higgs_model
Note that BoostedTreesClassifier is available since TensorFlow 1.8.0, so you
need to install a recent enough version of TensorFlow to use this example.
The training data is by default the first million examples out of 11M examples,
and eval data is by default the last million examples.
They are controlled by --train_start, --train_count, --eval_start, --eval_count.
e.g. to train over the first 10 million examples instead of 1 million:
$ python train_higgs.py --n_trees=100 --max_depth=6 --learning_rate=0.1 \
--model_dir=/tmp/higgs_model --train_count=10000000
Training history and metrics can be inspected using tensorboard.
Set --logdir to the --model_dir used during training
(or the default /tmp/higgs_model).
$ tensorboard --logdir=/tmp/higgs_model
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# pylint: disable=g-bad-import-order
import numpy as np
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.utils.flags._conventions import help_wrap
from official.utils.logs import logger
NPZ_FILE = "HIGGS.csv.gz.npz" # numpy compressed file containing "data" array
def read_higgs_data(data_dir, train_start, train_count, eval_start, eval_count):
"""Reads higgs data from csv and returns train and eval data.
Args:
data_dir: A string, the directory of higgs dataset.
train_start: An integer, the start index of train examples within the data.
train_count: An integer, the number of train examples within the data.
eval_start: An integer, the start index of eval examples within the data.
eval_count: An integer, the number of eval examples within the data.
Returns:
Numpy array of train data and eval data.
"""
npz_filename = os.path.join(data_dir, NPZ_FILE)
try:
# gfile allows numpy to read data from network data sources as well.
with tf.gfile.Open(npz_filename, "rb") as npz_file:
with np.load(npz_file) as npz:
data = npz["data"]
except tf.errors.NotFoundError as e:
raise RuntimeError(
"Error loading data; use data_download.py to prepare the data.\n{}: {}"
.format(type(e).__name__, e))
return (data[train_start:train_start+train_count],
data[eval_start:eval_start+eval_count])
# This showcases how to make input_fn when the input data is available in the
# form of numpy arrays.
def make_inputs_from_np_arrays(features_np, label_np):
"""Makes and returns input_fn and feature_columns from numpy arrays.
The generated input_fn will return tf.data.Dataset of feature dictionary and a
label, and feature_columns will consist of the list of
tf.feature_column.BucketizedColumn.
Note, for in-memory training, tf.data.Dataset should contain the whole data
as a single tensor. Don't use batch.
Args:
features_np: A numpy ndarray (shape=[batch_size, num_features]) for
float32 features.
label_np: A numpy ndarray (shape=[batch_size, 1]) for labels.
Returns:
input_fn: A function returning a Dataset of feature dict and label.
feature_names: A list of feature names.
feature_columns: A list of tf.feature_column.BucketizedColumn.
"""
num_features = features_np.shape[1]
features_np_list = np.split(features_np, num_features, axis=1)
# 1-based feature names.
feature_names = ["feature_%02d" % (i + 1) for i in range(num_features)]
# Create source feature_columns and bucketized_columns.
def get_bucket_boundaries(feature):
"""Returns bucket boundaries for feature by percentiles."""
return np.unique(np.percentile(feature, range(0, 100))).tolist()
source_columns = [
tf.feature_column.numeric_column(
feature_name, dtype=tf.float32,
# Although higgs data have no missing values, in general, default
# could be set as 0 or some reasonable value for missing values.
default_value=0.0)
for feature_name in feature_names
]
bucketized_columns = [
tf.feature_column.bucketized_column(
source_columns[i],
boundaries=get_bucket_boundaries(features_np_list[i]))
for i in range(num_features)
]
# Make an input_fn that extracts source features.
def input_fn():
"""Returns features as a dictionary of numpy arrays, and a label."""
features = {
feature_name: tf.constant(features_np_list[i])
for i, feature_name in enumerate(feature_names)
}
return tf.data.Dataset.zip((tf.data.Dataset.from_tensors(features),
tf.data.Dataset.from_tensors(label_np),))
return input_fn, feature_names, bucketized_columns
def make_eval_inputs_from_np_arrays(features_np, label_np):
"""Makes eval input as streaming batches."""
num_features = features_np.shape[1]
features_np_list = np.split(features_np, num_features, axis=1)
# 1-based feature names.
feature_names = ["feature_%02d" % (i + 1) for i in range(num_features)]
def input_fn():
features = {
feature_name: tf.constant(features_np_list[i])
for i, feature_name in enumerate(feature_names)
}
return tf.data.Dataset.zip((
tf.data.Dataset.from_tensor_slices(features),
tf.data.Dataset.from_tensor_slices(label_np),)).batch(1000)
return input_fn
def _make_csv_serving_input_receiver_fn(column_names, column_defaults):
"""Returns serving_input_receiver_fn for csv.
The input arguments are relevant to `tf.decode_csv()`.
Args:
column_names: a list of column names in the order within input csv.
column_defaults: a list of default values, the same size as column_names.
Each entry must be either a list of one scalar, or an empty list to
denote that the corresponding column is required.
e.g. [[""], [2.5], []] indicates the third column is required while
the first column must be string and the second must be float/double.
Returns:
a serving_input_receiver_fn that handles csv for serving.
"""
def serving_input_receiver_fn():
csv = tf.placeholder(dtype=tf.string, shape=[None], name="csv")
features = dict(zip(column_names, tf.decode_csv(csv, column_defaults)))
receiver_tensors = {"inputs": csv}
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
return serving_input_receiver_fn
def train_boosted_trees(flags_obj):
"""Train boosted_trees estimator on HIGGS data.
Args:
flags_obj: An object containing parsed flag values.
"""
# Clean up the model directory if present.
if tf.gfile.Exists(flags_obj.model_dir):
tf.gfile.DeleteRecursively(flags_obj.model_dir)
tf.logging.info("## Data loading...")
train_data, eval_data = read_higgs_data(
flags_obj.data_dir, flags_obj.train_start, flags_obj.train_count,
flags_obj.eval_start, flags_obj.eval_count)
tf.logging.info("## Data loaded; train: {}{}, eval: {}{}".format(
train_data.dtype, train_data.shape, eval_data.dtype, eval_data.shape))
# Data consists of one label column followed by 28 feature columns.
train_input_fn, feature_names, feature_columns = make_inputs_from_np_arrays(
features_np=train_data[:, 1:], label_np=train_data[:, 0:1])
eval_input_fn = make_eval_inputs_from_np_arrays(
features_np=eval_data[:, 1:], label_np=eval_data[:, 0:1])
tf.logging.info("## Features prepared. Training starts...")
# Create benchmark logger to log info about the training and metric values
run_params = {
"train_start": flags_obj.train_start,
"train_count": flags_obj.train_count,
"eval_start": flags_obj.eval_start,
"eval_count": flags_obj.eval_count,
"n_trees": flags_obj.n_trees,
"max_depth": flags_obj.max_depth,
}
benchmark_logger = logger.config_benchmark_logger(flags_obj)
benchmark_logger.log_run_info(
model_name="boosted_trees",
dataset_name="higgs",
run_params=run_params,
test_id=flags_obj.benchmark_test_id)
# Though BoostedTreesClassifier is under tf.estimator, faster in-memory
# training is as yet only provided as a contrib library.
classifier = tf.contrib.estimator.boosted_trees_classifier_train_in_memory(
train_input_fn,
feature_columns,
model_dir=flags_obj.model_dir or None,
n_trees=flags_obj.n_trees,
max_depth=flags_obj.max_depth,
learning_rate=flags_obj.learning_rate)
# Evaluation.
eval_results = classifier.evaluate(eval_input_fn)
# Benchmark the evaluation results
benchmark_logger.log_evaluation_result(eval_results)
# Export the SavedModel with csv parsing.
if flags_obj.export_dir is not None:
classifier.export_savedmodel(
flags_obj.export_dir,
_make_csv_serving_input_receiver_fn(
column_names=feature_names,
# columns are all floats.
column_defaults=[[0.0]] * len(feature_names)),
strip_default_attrs=True)
def main(_):
train_boosted_trees(flags.FLAGS)
def define_train_higgs_flags():
"""Add tree related flags as well as training/eval configuration."""
flags_core.define_base(clean=False, stop_threshold=False, batch_size=False,
num_gpu=False)
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_integer(
name="train_start", default=0,
help=help_wrap("Start index of train examples within the data."))
flags.DEFINE_integer(
name="train_count", default=1000000,
help=help_wrap("Number of train examples within the data."))
flags.DEFINE_integer(
name="eval_start", default=10000000,
help=help_wrap("Start index of eval examples within the data."))
flags.DEFINE_integer(
name="eval_count", default=1000000,
help=help_wrap("Number of eval examples within the data."))
flags.DEFINE_integer(
"n_trees", default=100, help=help_wrap("Number of trees to build."))
flags.DEFINE_integer(
"max_depth", default=6, help=help_wrap("Maximum depths of each tree."))
flags.DEFINE_float(
"learning_rate", default=0.1,
help=help_wrap("The learning rate."))
flags_core.set_defaults(data_dir="/tmp/higgs_data",
model_dir="/tmp/higgs_model")
if __name__ == "__main__":
# Training progress and eval results are shown as logging.INFO, so enable it.
tf.logging.set_verbosity(tf.logging.INFO)
define_train_higgs_flags()
absl_app.run(main)
1.000000000000000000e+00,8.692932128906250000e-01,-6.350818276405334473e-01,2.256902605295181274e-01,3.274700641632080078e-01,-6.899932026863098145e-01,7.542022466659545898e-01,-2.485731393098831177e-01,-1.092063903808593750e+00,0.000000000000000000e+00,1.374992132186889648e+00,-6.536741852760314941e-01,9.303491115570068359e-01,1.107436060905456543e+00,1.138904333114624023e+00,-1.578198313713073730e+00,-1.046985387802124023e+00,0.000000000000000000e+00,6.579295396804809570e-01,-1.045456994324922562e-02,-4.576716944575309753e-02,3.101961374282836914e+00,1.353760004043579102e+00,9.795631170272827148e-01,9.780761599540710449e-01,9.200048446655273438e-01,7.216574549674987793e-01,9.887509346008300781e-01,8.766783475875854492e-01
1.000000000000000000e+00,9.075421094894409180e-01,3.291472792625427246e-01,3.594118654727935791e-01,1.497969865798950195e+00,-3.130095303058624268e-01,1.095530629158020020e+00,-5.575249195098876953e-01,-1.588229775428771973e+00,2.173076152801513672e+00,8.125811815261840820e-01,-2.136419266462326050e-01,1.271014571189880371e+00,2.214872121810913086e+00,4.999939501285552979e-01,-1.261431813240051270e+00,7.321561574935913086e-01,0.000000000000000000e+00,3.987008929252624512e-01,-1.138930082321166992e+00,-8.191101951524615288e-04,0.000000000000000000e+00,3.022198975086212158e-01,8.330481648445129395e-01,9.856996536254882812e-01,9.780983924865722656e-01,7.797321677207946777e-01,9.923557639122009277e-01,7.983425855636596680e-01
1.000000000000000000e+00,7.988347411155700684e-01,1.470638751983642578e+00,-1.635974764823913574e+00,4.537731707096099854e-01,4.256291687488555908e-01,1.104874610900878906e+00,1.282322287559509277e+00,1.381664276123046875e+00,0.000000000000000000e+00,8.517372012138366699e-01,1.540658950805664062e+00,-8.196895122528076172e-01,2.214872121810913086e+00,9.934899210929870605e-01,3.560801148414611816e-01,-2.087775468826293945e-01,2.548224449157714844e+00,1.256954550743103027e+00,1.128847599029541016e+00,9.004608392715454102e-01,0.000000000000000000e+00,9.097532629966735840e-01,1.108330488204956055e+00,9.856922030448913574e-01,9.513312578201293945e-01,8.032515048980712891e-01,8.659244179725646973e-01,7.801175713539123535e-01
0.000000000000000000e+00,1.344384789466857910e+00,-8.766260147094726562e-01,9.359127283096313477e-01,1.992050051689147949e+00,8.824543952941894531e-01,1.786065936088562012e+00,-1.646777749061584473e+00,-9.423825144767761230e-01,0.000000000000000000e+00,2.423264741897583008e+00,-6.760157942771911621e-01,7.361586689949035645e-01,2.214872121810913086e+00,1.298719763755798340e+00,-1.430738091468811035e+00,-3.646581768989562988e-01,0.000000000000000000e+00,7.453126907348632812e-01,-6.783788204193115234e-01,-1.360356330871582031e+00,0.000000000000000000e+00,9.466524720191955566e-01,1.028703689575195312e+00,9.986560940742492676e-01,7.282806038856506348e-01,8.692002296447753906e-01,1.026736497879028320e+00,9.579039812088012695e-01
1.000000000000000000e+00,1.105008959770202637e+00,3.213555514812469482e-01,1.522401213645935059e+00,8.828076124191284180e-01,-1.205349326133728027e+00,6.814661026000976562e-01,-1.070463895797729492e+00,-9.218706488609313965e-01,0.000000000000000000e+00,8.008721470832824707e-01,1.020974040031433105e+00,9.714065194129943848e-01,2.214872121810913086e+00,5.967612862586975098e-01,-3.502728641033172607e-01,6.311942934989929199e-01,0.000000000000000000e+00,4.799988865852355957e-01,-3.735655248165130615e-01,1.130406111478805542e-01,0.000000000000000000e+00,7.558564543724060059e-01,1.361057043075561523e+00,9.866096973419189453e-01,8.380846381187438965e-01,1.133295178413391113e+00,8.722448945045471191e-01,8.084865212440490723e-01
0.000000000000000000e+00,1.595839262008666992e+00,-6.078106760978698730e-01,7.074915803968906403e-03,1.818449616432189941e+00,-1.119059920310974121e-01,8.475499153137207031e-01,-5.664370059967041016e-01,1.581239342689514160e+00,2.173076152801513672e+00,7.554209828376770020e-01,6.431096196174621582e-01,1.426366806030273438e+00,0.000000000000000000e+00,9.216607809066772461e-01,-1.190432429313659668e+00,-1.615589022636413574e+00,0.000000000000000000e+00,6.511141061782836914e-01,-6.542269587516784668e-01,-1.274344921112060547e+00,3.101961374282836914e+00,8.237605690956115723e-01,9.381914138793945312e-01,9.717581868171691895e-01,7.891763448715209961e-01,4.305532872676849365e-01,9.613569378852844238e-01,9.578179121017456055e-01
1.000000000000000000e+00,4.093913435935974121e-01,-1.884683609008789062e+00,-1.027292013168334961e+00,1.672451734542846680e+00,-1.604598283767700195e+00,1.338014960289001465e+00,5.542744323611259460e-02,1.346588134765625000e-02,2.173076152801513672e+00,5.097832679748535156e-01,-1.038338065147399902e+00,7.078623175621032715e-01,0.000000000000000000e+00,7.469175457954406738e-01,-3.584651052951812744e-01,-1.646654248237609863e+00,0.000000000000000000e+00,3.670579791069030762e-01,6.949646025896072388e-02,1.377130270004272461e+00,3.101961374282836914e+00,8.694183826446533203e-01,1.222082972526550293e+00,1.000627398490905762e+00,5.450449585914611816e-01,6.986525058746337891e-01,9.773144721984863281e-01,8.287860751152038574e-01
1.000000000000000000e+00,9.338953495025634766e-01,6.291297078132629395e-01,5.275348424911499023e-01,2.380327433347702026e-01,-9.665691256523132324e-01,5.478111505508422852e-01,-5.943922698497772217e-02,-1.706866145133972168e+00,2.173076152801513672e+00,9.410027265548706055e-01,-2.653732776641845703e+00,-1.572199910879135132e-01,0.000000000000000000e+00,1.030370354652404785e+00,-1.755051016807556152e-01,5.230209231376647949e-01,2.548224449157714844e+00,1.373546600341796875e+00,1.291248083114624023e+00,-1.467454433441162109e+00,0.000000000000000000e+00,9.018372893333435059e-01,1.083671212196350098e+00,9.796960949897766113e-01,7.833003997802734375e-01,8.491951823234558105e-01,8.943563103675842285e-01,7.748793959617614746e-01
1.000000000000000000e+00,1.405143737792968750e+00,5.366026163101196289e-01,6.895543336868286133e-01,1.179567337036132812e+00,-1.100611537694931030e-01,3.202404975891113281e+00,-1.526960015296936035e+00,-1.576033473014831543e+00,0.000000000000000000e+00,2.931536912918090820e+00,5.673424601554870605e-01,-1.300333440303802490e-01,2.214872121810913086e+00,1.787122726440429688e+00,8.994985818862915039e-01,5.851513147354125977e-01,2.548224449157714844e+00,4.018652141094207764e-01,-1.512016952037811279e-01,1.163489103317260742e+00,0.000000000000000000e+00,1.667070508003234863e+00,4.039272785186767578e+00,1.175828456878662109e+00,1.045351743698120117e+00,1.542971968650817871e+00,3.534826755523681641e+00,2.740753889083862305e+00
1.000000000000000000e+00,1.176565527915954590e+00,1.041605025529861450e-01,1.397002458572387695e+00,4.797213077545166016e-01,2.655133903026580811e-01,1.135563015937805176e+00,1.534830927848815918e+00,-2.532912194728851318e-01,0.000000000000000000e+00,1.027246594429016113e+00,5.343157649040222168e-01,1.180022358894348145e+00,0.000000000000000000e+00,2.405661106109619141e+00,8.755676448345184326e-02,-9.765340685844421387e-01,2.548224449157714844e+00,1.250382542610168457e+00,2.685412168502807617e-01,5.303344726562500000e-01,0.000000000000000000e+00,8.331748843193054199e-01,7.739681005477905273e-01,9.857499599456787109e-01,1.103696346282958984e+00,8.491398692131042480e-01,9.371039867401123047e-01,8.123638033866882324e-01
1.000000000000000000e+00,9.459739923477172852e-01,1.111244320869445801e+00,1.218337059020996094e+00,9.076390862464904785e-01,8.215369582176208496e-01,1.153243303298950195e+00,-3.654202818870544434e-01,-1.566054821014404297e+00,0.000000000000000000e+00,7.447192072868347168e-01,7.208195328712463379e-01,-3.758229315280914307e-01,2.214872121810913086e+00,6.088791489601135254e-01,3.078369498252868652e-01,-1.281638383865356445e+00,0.000000000000000000e+00,1.597967982292175293e+00,-4.510180354118347168e-01,6.365344673395156860e-02,3.101961374282836914e+00,8.290241360664367676e-01,9.806482791900634766e-01,9.943597912788391113e-01,9.082478284835815430e-01,7.758789062500000000e-01,7.833113670349121094e-01,7.251217961311340332e-01
0.000000000000000000e+00,7.393567562103271484e-01,-1.782904267311096191e-01,8.299342393875122070e-01,5.045390725135803223e-01,-1.302167475223541260e-01,9.610513448715209961e-01,-3.555179834365844727e-01,-1.717399358749389648e+00,2.173076152801513672e+00,6.209560632705688477e-01,-4.817410409450531006e-01,-1.199193239212036133e+00,0.000000000000000000e+00,9.826014041900634766e-01,8.118502795696258545e-02,-2.903236448764801025e-01,0.000000000000000000e+00,1.064662933349609375e+00,7.740649580955505371e-01,3.988203406333923340e-01,3.101961374282836914e+00,9.445360302925109863e-01,1.026260614395141602e+00,9.821967482566833496e-01,5.421146750450134277e-01,1.250978946685791016e+00,8.300446271896362305e-01,7.613079547882080078e-01
1.000000000000000000e+00,1.384097695350646973e+00,1.168220937252044678e-01,-1.179878950119018555e+00,7.629125714302062988e-01,-7.978226989507675171e-02,1.019863128662109375e+00,8.773182630538940430e-01,1.276887178421020508e+00,2.173076152801513672e+00,3.312520980834960938e-01,1.409523487091064453e+00,-1.474388837814331055e+00,0.000000000000000000e+00,1.282738208770751953e+00,7.374743819236755371e-01,-2.254196107387542725e-01,0.000000000000000000e+00,1.559753060340881348e+00,8.465205430984497070e-01,5.048085451126098633e-01,3.101961374282836914e+00,9.593246579170227051e-01,8.073760271072387695e-01,1.191813588142395020e+00,1.221210360527038574e+00,8.611412644386291504e-01,9.293408989906311035e-01,8.383023738861083984e-01
1.000000000000000000e+00,1.383548736572265625e+00,8.891792893409729004e-01,6.185320615768432617e-01,1.081547021865844727e+00,3.446055650711059570e-01,9.563793540000915527e-01,8.545429706573486328e-01,-1.129207015037536621e+00,2.173076152801513672e+00,5.456657409667968750e-01,-3.078651726245880127e-01,-6.232798099517822266e-01,2.214872121810913086e+00,3.482571244239807129e-01,1.024202585220336914e+00,1.840776652097702026e-01,0.000000000000000000e+00,7.813369035720825195e-01,-1.636125564575195312e+00,1.144067287445068359e+00,0.000000000000000000e+00,5.222384929656982422e-01,7.376385331153869629e-01,9.861995577812194824e-01,1.349615693092346191e+00,8.127878904342651367e-01,9.534064531326293945e-01,7.797226309776306152e-01
1.000000000000000000e+00,1.343652725219726562e+00,8.385329246520996094e-01,-1.061138510704040527e+00,2.472015142440795898e+00,-5.726317167282104492e-01,1.512709975242614746e+00,1.143690109252929688e+00,8.555619716644287109e-01,0.000000000000000000e+00,8.842203021049499512e-01,1.474605560302734375e+00,-1.360648751258850098e+00,1.107436060905456543e+00,1.587265610694885254e+00,2.234833478927612305e+00,7.756848633289337158e-02,0.000000000000000000e+00,1.609408140182495117e+00,2.396404743194580078e+00,7.572935223579406738e-01,0.000000000000000000e+00,9.340201020240783691e-01,8.447072505950927734e-01,1.077844023704528809e+00,1.400183677673339844e+00,9.477745294570922852e-01,1.007614254951477051e+00,9.010174870491027832e-01
0.000000000000000000e+00,5.470141768455505371e-01,-3.497089445590972900e-01,-6.466571688652038574e-01,2.040462255477905273e+00,2.764569818973541260e-01,5.446965098381042480e-01,8.386992812156677246e-01,1.728703141212463379e+00,0.000000000000000000e+00,6.528096199035644531e-01,1.471691370010375977e+00,1.243273019790649414e+00,0.000000000000000000e+00,7.857298851013183594e-01,-4.442929103970527649e-02,-1.019803404808044434e+00,2.548224449157714844e+00,4.191471040248870850e-01,-6.292421817779541016e-01,1.570794582366943359e+00,3.101961374282836914e+00,6.894335746765136719e-01,8.672295808792114258e-01,1.082487821578979492e+00,6.641419529914855957e-01,3.541145622730255127e-01,5.799450278282165527e-01,8.172734379768371582e-01
1.000000000000000000e+00,1.484203696250915527e+00,1.699521422386169434e+00,-1.059473991394042969e+00,2.700195550918579102e+00,-1.055963873863220215e+00,2.409452915191650391e+00,4.574607908725738525e-01,3.449823260307312012e-01,0.000000000000000000e+00,1.414903521537780762e+00,1.114225864410400391e+00,-1.448866605758666992e+00,0.000000000000000000e+00,1.012983918190002441e+00,-2.056988954544067383e+00,1.131010890007019043e+00,0.000000000000000000e+00,9.054746031761169434e-01,2.182368993759155273e+00,1.043073177337646484e+00,0.000000000000000000e+00,1.653626322746276855e+00,9.935762286186218262e-01,9.833217859268188477e-01,7.413797974586486816e-01,1.633816361427307129e-01,5.923243165016174316e-01,7.451378703117370605e-01
0.000000000000000000e+00,1.057975649833679199e+00,-1.607590019702911377e-01,-1.949972510337829590e-01,2.705023050308227539e+00,-7.514767050743103027e-01,1.909918904304504395e+00,-1.031844973564147949e+00,8.649863600730895996e-01,0.000000000000000000e+00,1.300834894180297852e+00,1.467376798391342163e-01,-1.118742942810058594e+00,1.107436060905456543e+00,9.669710993766784668e-01,-3.666573464870452881e-01,1.108266711235046387e+00,0.000000000000000000e+00,5.547249317169189453e-01,-7.141901850700378418e-01,1.505314946174621582e+00,3.101961374282836914e+00,9.544943571090698242e-01,6.510385870933532715e-01,1.124949693679809570e+00,8.940010070800781250e-01,6.721734404563903809e-01,1.182358264923095703e+00,1.316304087638854980e+00
0.000000000000000000e+00,6.753035783767700195e-01,1.120983958244323730e+00,-2.804459035396575928e-01,1.539554953575134277e+00,7.345175743103027344e-01,6.146844029426574707e-01,-5.070231556892395020e-01,7.945806980133056641e-01,2.173076152801513672e+00,2.188202738761901855e-01,-1.894118309020996094e+00,-5.805578827857971191e-01,0.000000000000000000e+00,1.245682120323181152e+00,-3.475421071052551270e-01,-8.561564683914184570e-01,2.548224449157714844e+00,7.531017661094665527e-01,-1.145592689514160156e+00,-1.374783992767333984e+00,0.000000000000000000e+00,9.069401025772094727e-01,8.983390927314758301e-01,1.119651079177856445e+00,1.269073486328125000e+00,1.088765859603881836e+00,1.015413045883178711e+00,9.146358966827392578e-01
1.000000000000000000e+00,6.427279114723205566e-01,-1.429840326309204102e+00,1.519071936607360840e+00,9.409985542297363281e-01,8.872274160385131836e-01,1.615126848220825195e+00,-1.336835741996765137e+00,-2.665962278842926025e-01,1.086538076400756836e+00,1.667088270187377930e+00,6.557375192642211914e-01,-1.588128924369812012e+00,0.000000000000000000e+00,8.282302021980285645e-01,1.836144566535949707e+00,4.081907570362091064e-01,0.000000000000000000e+00,1.708718180656433105e+00,-3.469151556491851807e-01,-1.182784557342529297e+00,3.101961374282836914e+00,9.210902452468872070e-01,1.373361706733703613e+00,9.849172830581665039e-01,1.422878146171569824e+00,1.546551108360290527e+00,1.782585501670837402e+00,1.438173770904541016e+00
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for boosted_tree."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import numpy as np
import pandas as pd
import tensorflow as tf
# pylint: disable=g-bad-import-order
from official.boosted_trees import train_higgs
from official.utils.testing import integration
TEST_CSV = os.path.join(os.path.dirname(__file__), "train_higgs_test.csv")
tf.logging.set_verbosity(tf.logging.ERROR)
class BaseTest(tf.test.TestCase):
"""Tests for Wide Deep model."""
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass()
train_higgs.define_train_higgs_flags()
def setUp(self):
# Create temporary CSV file
self.data_dir = self.get_temp_dir()
data = pd.read_csv(
TEST_CSV, dtype=np.float32, names=["c%02d" % i for i in range(29)]
).as_matrix()
self.input_npz = os.path.join(self.data_dir, train_higgs.NPZ_FILE)
# numpy.savez doesn't take a gfile.GFile, so write to a local file and copy.
tmpfile = tempfile.NamedTemporaryFile()
np.savez_compressed(tmpfile, data=data)
tf.gfile.Copy(tmpfile.name, self.input_npz)
def test_read_higgs_data(self):
"""Tests read_higgs_data() function."""
# Error when a wrong data_dir is given.
with self.assertRaisesRegexp(RuntimeError, "Error loading data.*"):
train_data, eval_data = train_higgs.read_higgs_data(
self.data_dir + "non-existing-path",
train_start=0, train_count=15, eval_start=15, eval_count=5)
# Loading fine with the correct data_dir.
train_data, eval_data = train_higgs.read_higgs_data(
self.data_dir,
train_start=0, train_count=15, eval_start=15, eval_count=5)
self.assertEqual((15, 29), train_data.shape)
self.assertEqual((5, 29), eval_data.shape)
def test_make_inputs_from_np_arrays(self):
"""Tests make_inputs_from_np_arrays() function."""
train_data, _ = train_higgs.read_higgs_data(
self.data_dir,
train_start=0, train_count=15, eval_start=15, eval_count=5)
(input_fn, feature_names,
feature_columns) = train_higgs.make_inputs_from_np_arrays(
features_np=train_data[:, 1:], label_np=train_data[:, 0:1])
# Check feature_names.
self.assertAllEqual(feature_names,
["feature_%02d" % (i+1) for i in range(28)])
# Check feature columns.
self.assertEqual(28, len(feature_columns))
bucketized_column_type = type(
tf.feature_column.bucketized_column(
tf.feature_column.numeric_column("feature_01"),
boundaries=[0, 1, 2])) # dummy boundaries.
for feature_column in feature_columns:
self.assertIsInstance(feature_column, bucketized_column_type)
# At least 2 boundaries.
self.assertGreaterEqual(len(feature_column.boundaries), 2)
# Tests that the source column names of the bucketized columns match.
self.assertAllEqual(feature_names,
[col.source_column.name for col in feature_columns])
# Check features.
features, labels = input_fn().make_one_shot_iterator().get_next()
with tf.Session() as sess:
features, labels = sess.run((features, labels))
self.assertIsInstance(features, dict)
self.assertAllEqual(feature_names, sorted(features.keys()))
self.assertAllEqual([[15, 1]] * 28,
[features[name].shape for name in feature_names])
# Validate actual values of some features.
self.assertAllClose(
[0.869293, 0.907542, 0.798834, 1.344384, 1.105009, 1.595839,
0.409391, 0.933895, 1.405143, 1.176565, 0.945974, 0.739356,
1.384097, 1.383548, 1.343652],
np.squeeze(features[feature_names[0]], 1))
self.assertAllClose(
[-0.653674, -0.213641, 1.540659, -0.676015, 1.020974, 0.643109,
-1.038338, -2.653732, 0.567342, 0.534315, 0.720819, -0.481741,
1.409523, -0.307865, 1.474605],
np.squeeze(features[feature_names[10]], 1))
def test_end_to_end(self):
"""Tests end-to-end running."""
model_dir = os.path.join(self.get_temp_dir(), "model")
integration.run_synthetic(
main=train_higgs.main, tmp_root=self.get_temp_dir(), extra_flags=[
"--data_dir", self.data_dir,
"--model_dir", model_dir,
"--n_trees", "5",
"--train_start", "0",
"--train_count", "12",
"--eval_start", "12",
"--eval_count", "8",
],
synth=False, max_train=None)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
def test_end_to_end_with_export(self):
"""Tests end-to-end running."""
model_dir = os.path.join(self.get_temp_dir(), "model")
export_dir = os.path.join(self.get_temp_dir(), "export")
integration.run_synthetic(
main=train_higgs.main, tmp_root=self.get_temp_dir(), extra_flags=[
"--data_dir", self.data_dir,
"--model_dir", model_dir,
"--export_dir", export_dir,
"--n_trees", "5",
"--train_start", "0",
"--train_count", "12",
"--eval_start", "12",
"--eval_count", "8",
],
synth=False, max_train=None)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
self.assertTrue(tf.gfile.Exists(os.path.join(export_dir)))
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and extract the MovieLens dataset from GroupLens website.
Download the dataset, and perform basic preprocessing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import tempfile
import zipfile
# pylint: disable=g-bad-import-order
import numpy as np
import pandas as pd
import six
from six.moves import urllib # pylint: disable=redefined-builtin
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.utils.flags import core as flags_core
ML_1M = "ml-1m"
ML_20M = "ml-20m"
DATASETS = [ML_1M, ML_20M]
RATINGS_FILE = "ratings.csv"
MOVIES_FILE = "movies.csv"
# URL to download dataset
_DATA_URL = "http://files.grouplens.org/datasets/movielens/"
GENRE_COLUMN = "genres"
ITEM_COLUMN = "item_id" # movies
RATING_COLUMN = "rating"
TIMESTAMP_COLUMN = "timestamp"
TITLE_COLUMN = "titles"
USER_COLUMN = "user_id"
GENRES = [
'Action', 'Adventure', 'Animation', "Children", 'Comedy', 'Crime',
'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', "IMAX", 'Musical',
'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western'
]
N_GENRE = len(GENRES)
RATING_COLUMNS = [USER_COLUMN, ITEM_COLUMN, RATING_COLUMN, TIMESTAMP_COLUMN]
MOVIE_COLUMNS = [ITEM_COLUMN, TITLE_COLUMN, GENRE_COLUMN]
# Note: Users are indexed [1, k], not [0, k-1]
NUM_USER_IDS = {
ML_1M: 6040,
ML_20M: 138493,
}
# Note: Movies are indexed [1, k], not [0, k-1]
# Both the 1m and 20m datasets use the same movie set.
NUM_ITEM_IDS = 3952
MAX_RATING = 5
NUM_RATINGS = {
ML_1M: 1000209,
ML_20M: 20000263
}
def _download_and_clean(dataset, data_dir):
"""Download MovieLens dataset in a standard format.
This function downloads the specified MovieLens dataset and coerces it into a
standard format. The only difference between the ml-1m and ml-20m datasets
after this point (other than size, of course) is that the 1m dataset uses
whole number ratings while the 20m dataset allows half integer ratings.
"""
if dataset not in DATASETS:
raise ValueError("dataset {} is not in {{{}}}".format(
dataset, ",".join(DATASETS)))
data_subdir = os.path.join(data_dir, dataset)
expected_files = ["{}.zip".format(dataset), RATINGS_FILE, MOVIES_FILE]
tf.gfile.MakeDirs(data_subdir)
if set(expected_files).intersection(
tf.gfile.ListDirectory(data_subdir)) == set(expected_files):
tf.logging.info("Dataset {} has already been downloaded".format(dataset))
return
url = "{}{}.zip".format(_DATA_URL, dataset)
temp_dir = tempfile.mkdtemp()
try:
zip_path = os.path.join(temp_dir, "{}.zip".format(dataset))
zip_path, _ = urllib.request.urlretrieve(url, zip_path)
statinfo = os.stat(zip_path)
# Print a newline to clear the carriage return from the download progress;
# tf.logging.info is not applicable here.
print()
tf.logging.info(
"Successfully downloaded {} {} bytes".format(
zip_path, statinfo.st_size))
zipfile.ZipFile(zip_path, "r").extractall(temp_dir)
if dataset == ML_1M:
_regularize_1m_dataset(temp_dir)
else:
_regularize_20m_dataset(temp_dir)
for fname in tf.gfile.ListDirectory(temp_dir):
if not tf.gfile.Exists(os.path.join(data_subdir, fname)):
tf.gfile.Copy(os.path.join(temp_dir, fname),
os.path.join(data_subdir, fname))
else:
tf.logging.info("Skipping copy of {}, as it already exists in the "
"destination folder.".format(fname))
finally:
tf.gfile.DeleteRecursively(temp_dir)
def _transform_csv(input_path, output_path, names, skip_first, separator=","):
"""Transform csv to a regularized format.
Args:
input_path: The path of the raw csv.
output_path: The path of the cleaned csv.
names: The csv column names.
skip_first: Boolean of whether to skip the first line of the raw csv.
separator: Character used to separate fields in the raw csv.
"""
if six.PY2:
names = [n.decode("utf-8") for n in names]
with tf.gfile.Open(output_path, "wb") as f_out, \
tf.gfile.Open(input_path, "rb") as f_in:
# Write column names to the csv.
f_out.write(",".join(names).encode("utf-8"))
f_out.write(b"\n")
for i, line in enumerate(f_in):
if i == 0 and skip_first:
continue # ignore existing labels in the csv
line = line.decode("utf-8", errors="ignore")
fields = line.split(separator)
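      # When converting from a non-comma separator, quote any field that
      # contains a comma so the output csv remains parseable.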
if separator != ",":
fields = ['"{}"'.format(field) if "," in field else field
for field in fields]
f_out.write(",".join(fields).encode("utf-8"))
def _regularize_1m_dataset(temp_dir):
"""
ratings.dat
The file has no header row, and each line is in the following format:
UserID::MovieID::Rating::Timestamp
  - UserIDs range from 1 to 6040
  - MovieIDs range from 1 to 3952
- Ratings are made on a 5-star scale (whole-star ratings only)
- Timestamp is represented in seconds since midnight Coordinated Universal
Time (UTC) of January 1, 1970.
- Each user has at least 20 ratings
movies.dat
Each line has the following format:
MovieID::Title::Genres
  - MovieIDs range from 1 to 3952
"""
working_dir = os.path.join(temp_dir, ML_1M)
_transform_csv(
input_path=os.path.join(working_dir, "ratings.dat"),
output_path=os.path.join(temp_dir, RATINGS_FILE),
names=RATING_COLUMNS, skip_first=False, separator="::")
_transform_csv(
input_path=os.path.join(working_dir, "movies.dat"),
output_path=os.path.join(temp_dir, MOVIES_FILE),
names=MOVIE_COLUMNS, skip_first=False, separator="::")
tf.gfile.DeleteRecursively(working_dir)
def _regularize_20m_dataset(temp_dir):
"""
ratings.csv
Each line of this file after the header row represents one rating of one
movie by one user, and has the following format:
userId,movieId,rating,timestamp
- The lines within this file are ordered first by userId, then, within user,
by movieId.
- Ratings are made on a 5-star scale, with half-star increments
(0.5 stars - 5.0 stars).
- Timestamps represent seconds since midnight Coordinated Universal Time
(UTC) of January 1, 1970.
  - Each user has rated at least 20 movies.
movies.csv
Each line has the following format:
MovieID,Title,Genres
  - MovieIDs range from 1 to 3952
"""
working_dir = os.path.join(temp_dir, ML_20M)
_transform_csv(
input_path=os.path.join(working_dir, "ratings.csv"),
output_path=os.path.join(temp_dir, RATINGS_FILE),
names=RATING_COLUMNS, skip_first=True, separator=",")
_transform_csv(
input_path=os.path.join(working_dir, "movies.csv"),
output_path=os.path.join(temp_dir, MOVIES_FILE),
names=MOVIE_COLUMNS, skip_first=True, separator=",")
tf.gfile.DeleteRecursively(working_dir)
def download(dataset, data_dir):
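  """Download (and clean) the specified dataset, or all datasets if None."""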
if dataset:
_download_and_clean(dataset, data_dir)
else:
    for d in DATASETS:
      _download_and_clean(d, data_dir)
def ratings_csv_to_dataframe(data_dir, dataset):
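  """Load the cleaned ratings csv into a pandas DataFrame."""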
with tf.gfile.Open(os.path.join(data_dir, dataset, RATINGS_FILE)) as f:
return pd.read_csv(f, encoding="utf-8")
def csv_to_joint_dataframe(data_dir, dataset):
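  """Merge the ratings and movies csv files into a single pandas DataFrame."""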
ratings = ratings_csv_to_dataframe(data_dir, dataset)
with tf.gfile.Open(os.path.join(data_dir, dataset, MOVIES_FILE)) as f:
movies = pd.read_csv(f, encoding="utf-8")
df = ratings.merge(movies, on=ITEM_COLUMN)
df[RATING_COLUMN] = df[RATING_COLUMN].astype(np.float32)
return df
def integerize_genres(dataframe):
"""Replace genre string with a binary vector.
Args:
dataframe: a pandas dataframe of movie data.
Returns:
The transformed dataframe.
"""
def _map_fn(entry):
entry.replace("Children's", "Children") # naming difference.
movie_genres = entry.split("|")
output = np.zeros((len(GENRES),), dtype=np.int64)
for i, genre in enumerate(GENRES):
if genre in movie_genres:
output[i] = 1
return output
dataframe[GENRE_COLUMN] = dataframe[GENRE_COLUMN].apply(_map_fn)
return dataframe
def define_data_download_flags():
"""Add flags specifying data download arguments."""
flags.DEFINE_string(
name="data_dir", default="/tmp/movielens-data/",
help=flags_core.help_wrap(
"Directory to download and extract data."))
flags.DEFINE_enum(
name="dataset", default=None,
enum_values=DATASETS, case_sensitive=False,
help=flags_core.help_wrap("Dataset to be trained and evaluated."))
def main(_):
"""Download and extract the data from GroupLens website."""
download(flags.FLAGS.dataset, flags.FLAGS.data_dir)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
define_data_download_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
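# Example invocation (a sketch, assuming this script is saved as movielens.py):
#   python movielens.py --dataset ml-1m --data_dir /tmp/movielens-data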
# Keras Application Models Benchmark
## Overview
This directory provides a single scaffold to benchmark the Keras built-in application [models](https://keras.io/applications/). All of the models are for image classification, and include:
- Xception
- VGG16
- VGG19
- ResNet50
- InceptionV3
- InceptionResNetV2
- MobileNet
- DenseNet
- NASNet
## Dataset
A synthetic dataset is used for the benchmark by default.
## Callbacks
Two custom callbacks are provided for model benchmarking: ExamplesPerSecondCallback and LoggingMetricCallback. For each callback, `epoch_based` and `batch_based` options are available to set the benchmark level. Check [model_callbacks.py](model_callbacks.py) for more details.
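As a minimal sketch of how the callback factory in [model_callbacks.py](model_callbacks.py) can be wired up (assuming the models repository is on your Python path; the values here are illustrative):
```
from official.keras_application_models import model_callbacks

# Build both callbacks by (case-insensitive) name; extra keyword arguments
# are forwarded to each callback's factory function.
callbacks = model_callbacks.get_model_callbacks(
    ["ExamplesPerSecondCallback", "LoggingMetricCallback"],
    batch_size=32,       # used by ExamplesPerSecondCallback
    metric_logger=None)  # falls back to logger.get_benchmark_logger()
# The returned list can be passed to model.fit(..., callbacks=callbacks).
```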
## Running Code
To benchmark a model, use `--model` to specify the model name. To perform the benchmark with eager execution, issue the following command:
```
python benchmark_main.py --model resnet50 --eager
```
Note that if eager execution is enabled, only one GPU is utilized even if multiple GPUs are provided and `multi_gpu_model` is used.
To use distribution strategy in the benchmark, run the following:
```
python benchmark_main.py --model resnet50 --dist_strat
```
Currently, only one of the `--eager` and `--dist_strat` arguments can be set, as DistributionStrategy is not yet supported in eager execution.
Arguments:
* `--model`: The model to benchmark. Valid model names are the keys of `MODELS` in [benchmark_main.py](benchmark_main.py).
* `--callbacks`: A comma-separated list of callbacks to use during training, as in the example below.
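For example, to benchmark MobileNet with only the examples-per-second callback:
```
python benchmark_main.py --model mobilenet --callbacks ExamplesPerSecondCallback
```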
Use the `--help` or `-h` flag to get a full list of possible arguments.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark on the keras built-in application models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order
import numpy as np
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.keras_application_models import dataset
from official.keras_application_models import model_callbacks
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
# Define a dictionary that maps model names to their model classes inside Keras
MODELS = {
"vgg16": tf.keras.applications.VGG16,
"vgg19": tf.keras.applications.VGG19,
"inceptionv3": tf.keras.applications.InceptionV3,
"xception": tf.keras.applications.Xception,
"resnet50": tf.keras.applications.ResNet50,
"inceptionresnetv2": tf.keras.applications.InceptionResNetV2,
"mobilenet": tf.keras.applications.MobileNet,
"densenet121": tf.keras.applications.DenseNet121,
"densenet169": tf.keras.applications.DenseNet169,
"densenet201": tf.keras.applications.DenseNet201,
"nasnetlarge": tf.keras.applications.NASNetLarge,
"nasnetmobile": tf.keras.applications.NASNetMobile,
}
def run_keras_model_benchmark(_):
"""Run the benchmark on keras model."""
# Ensure a valid model name was supplied via command line argument
if FLAGS.model not in MODELS.keys():
raise AssertionError("The --model command line argument should "
"be a key in the `MODELS` dictionary.")
# Check if eager execution is enabled
if FLAGS.eager:
tf.logging.info("Eager execution is enabled...")
tf.enable_eager_execution()
# Load the model
tf.logging.info("Benchmark on {} model...".format(FLAGS.model))
keras_model = MODELS[FLAGS.model]
# Get dataset
dataset_name = "ImageNet"
if FLAGS.use_synthetic_data:
tf.logging.info("Using synthetic dataset...")
dataset_name += "_Synthetic"
train_dataset = dataset.generate_synthetic_input_dataset(
FLAGS.model, FLAGS.batch_size)
val_dataset = dataset.generate_synthetic_input_dataset(
FLAGS.model, FLAGS.batch_size)
model = keras_model(weights=None)
else:
tf.logging.info("Using CIFAR-10 dataset...")
dataset_name = "CIFAR-10"
ds = dataset.Cifar10Dataset(FLAGS.batch_size)
train_dataset = ds.train_dataset
val_dataset = ds.test_dataset
model = keras_model(
weights=None, input_shape=ds.input_shape, classes=ds.num_classes)
num_gpus = flags_core.get_num_gpus(FLAGS)
distribution = None
# Use distribution strategy
if FLAGS.dist_strat:
distribution = distribution_utils.get_distribution_strategy(
num_gpus=num_gpus)
elif num_gpus > 1:
# Run with multi_gpu_model
# If eager execution is enabled, only one GPU is utilized even if multiple
# GPUs are provided.
if FLAGS.eager:
tf.logging.warning(
"{} GPUs are provided, but only one GPU is utilized as "
"eager execution is enabled.".format(num_gpus))
model = tf.keras.utils.multi_gpu_model(model, gpus=num_gpus)
  # The Adam optimizer and some other optimizers don't work well with
  # distribution strategy (b/113076709), so use GradientDescentOptimizer here.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
model.compile(loss="categorical_crossentropy",
optimizer=optimizer,
metrics=["accuracy"],
distribute=distribution)
# Create benchmark logger for benchmark logging
run_params = {
"batch_size": FLAGS.batch_size,
"synthetic_data": FLAGS.use_synthetic_data,
"train_epochs": FLAGS.train_epochs,
"num_train_images": FLAGS.num_train_images,
"num_eval_images": FLAGS.num_eval_images,
}
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info(
model_name=FLAGS.model,
dataset_name=dataset_name,
run_params=run_params,
test_id=FLAGS.benchmark_test_id)
# Create callbacks that log metric values about the training and evaluation
callbacks = model_callbacks.get_model_callbacks(
FLAGS.callbacks,
batch_size=FLAGS.batch_size,
metric_logger=benchmark_logger)
# Train and evaluate the model
history = model.fit(
train_dataset,
epochs=FLAGS.train_epochs,
callbacks=callbacks,
validation_data=val_dataset,
steps_per_epoch=int(np.ceil(FLAGS.num_train_images / FLAGS.batch_size)),
validation_steps=int(np.ceil(FLAGS.num_eval_images / FLAGS.batch_size))
)
tf.logging.info("Logging the evaluation results...")
for epoch in range(FLAGS.train_epochs):
eval_results = {
"accuracy": history.history["val_acc"][epoch],
"loss": history.history["val_loss"][epoch],
tf.GraphKeys.GLOBAL_STEP: (epoch + 1) * np.ceil(
FLAGS.num_eval_images/FLAGS.batch_size)
}
benchmark_logger.log_evaluation_result(eval_results)
# Clear the session explicitly to avoid session delete error
tf.keras.backend.clear_session()
def define_keras_benchmark_flags():
"""Add flags for keras built-in application models."""
flags_core.define_base(hooks=False)
flags_core.define_performance()
flags_core.define_image()
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
flags_core.set_defaults(
data_format="channels_last",
use_synthetic_data=True,
batch_size=32,
train_epochs=2)
flags.DEFINE_enum(
name="model", default=None,
enum_values=MODELS.keys(), case_sensitive=False,
help=flags_core.help_wrap(
"Model to be benchmarked."))
flags.DEFINE_integer(
name="num_train_images", default=1000,
help=flags_core.help_wrap(
"The number of synthetic images for training. The default value is "
"1000."))
flags.DEFINE_integer(
name="num_eval_images", default=50,
help=flags_core.help_wrap(
"The number of synthetic images for evaluation. The default value is "
"50."))
flags.DEFINE_boolean(
name="eager", default=False, help=flags_core.help_wrap(
"To enable eager execution. Note that if eager execution is enabled, "
"only one GPU is utilized even if multiple GPUs are provided and "
"multi_gpu_model is used."))
flags.DEFINE_boolean(
name="dist_strat", default=False, help=flags_core.help_wrap(
"To enable distribution strategy for model training and evaluation. "
"Number of GPUs used for distribution strategy can be set by the "
"argument --num_gpus."))
flags.DEFINE_list(
name="callbacks",
default=["ExamplesPerSecondCallback", "LoggingMetricCallback"],
help=flags_core.help_wrap(
"A list of (case insensitive) strings to specify the names of "
"callbacks. For example: `--callbacks ExamplesPerSecondCallback,"
"LoggingMetricCallback`"))
@flags.multi_flags_validator(
["eager", "dist_strat"],
message="Both --eager and --dist_strat were set. Only one can be "
"defined, as DistributionStrategy is not supported in Eager "
"execution currently.")
# pylint: disable=unused-variable
def _check_eager_dist_strat(flag_dict):
  return not (flag_dict["eager"] and flag_dict["dist_strat"])
def main(_):
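  """Run the benchmark within a benchmark logging context."""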
with logger.benchmark_context(FLAGS):
run_keras_model_benchmark(FLAGS)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
define_keras_benchmark_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Prepare dataset for keras model benchmark."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from official.utils.misc import model_helpers # pylint: disable=g-bad-import-order
# Default values for dataset.
_NUM_CHANNELS = 3
_NUM_CLASSES = 1000
def _get_default_image_size(model):
"""Provide default image size for each model."""
image_size = (224, 224)
if model in ["inceptionv3", "xception", "inceptionresnetv2"]:
image_size = (299, 299)
elif model in ["nasnetlarge"]:
image_size = (331, 331)
return image_size
def generate_synthetic_input_dataset(model, batch_size):
"""Generate synthetic dataset."""
image_size = _get_default_image_size(model)
image_shape = (batch_size,) + image_size + (_NUM_CHANNELS,)
label_shape = (batch_size, _NUM_CLASSES)
dataset = model_helpers.generate_synthetic_data(
input_shape=tf.TensorShape(image_shape),
label_shape=tf.TensorShape(label_shape),
)
return dataset
class Cifar10Dataset(object):
"""CIFAR10 dataset, including train and test set.
  Each sample consists of a 32x32 color image and a label from one of 10 classes.
"""
def __init__(self, batch_size):
"""Initializes train/test datasets.
Args:
      batch_size: int, the batch size.
"""
self.input_shape = (32, 32, 3)
self.num_classes = 10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)
y_train = tf.keras.utils.to_categorical(y_train, self.num_classes)
y_test = tf.keras.utils.to_categorical(y_test, self.num_classes)
self.train_dataset = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(2000).batch(batch_size).repeat()
self.test_dataset = tf.data.Dataset.from_tensor_slices(
(x_test, y_test)).shuffle(2000).batch(batch_size).repeat()
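# Example usage (a sketch; assumes a compiled Keras `model` is in scope):
#   ds = Cifar10Dataset(batch_size=32)
#   model.fit(ds.train_dataset, epochs=2,
#             steps_per_epoch=50000 // 32)  # CIFAR-10 has 50,000 train images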
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Callbacks for Keras built-in application models.
Note that the global_step is initialized in the __init__ of each callback
rather than in on_train_begin, because on_train_begin is called inside the
fit loop and would reset the step with each call to fit(). Initializing the
global_step in __init__ keeps it persistent across all training sessions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logs import logger
# Metrics to log after each batch and epoch
_PER_BATCH_METRICS = {
"loss": "train_loss",
"acc": "train_accuracy",
}
_PER_EPOCH_METRICS = {
"loss": "train_loss",
"acc": "train_accuracy",
"val_loss": "loss",
"val_acc": "accuracy"
}
class ExamplesPerSecondCallback(tf.keras.callbacks.Callback):
"""ExamplesPerSecond callback.
This callback records the average_examples_per_sec and
current_examples_per_sec during training.
"""
def __init__(self, batch_size, every_n_steps=1, metric_logger=None):
self._batch_size = batch_size
self._every_n_steps = every_n_steps
self._logger = metric_logger or logger.BaseBenchmarkLogger()
self._global_step = 0 # Initialize it in __init__
super(ExamplesPerSecondCallback, self).__init__()
def on_train_begin(self, logs=None):
self._train_start_time = time.time()
self._last_recorded_time = time.time()
def on_batch_end(self, batch, logs=None):
"""Log the examples_per_sec metric every_n_steps."""
self._global_step += 1
current_time = time.time()
if self._global_step % self._every_n_steps == 0:
average_examples_per_sec = self._batch_size * (
self._global_step / (current_time - self._train_start_time))
self._logger.log_metric(
"average_examples_per_sec", average_examples_per_sec,
global_step=self._global_step)
current_examples_per_sec = self._batch_size * (
self._every_n_steps / (current_time - self._last_recorded_time))
self._logger.log_metric(
"current_examples_per_sec", current_examples_per_sec,
global_step=self._global_step)
self._last_recorded_time = current_time # Update last_recorded_time
class LoggingMetricCallback(tf.keras.callbacks.Callback):
"""LoggingMetric callback.
Log the predefined _PER_BATCH_METRICS after each batch, and log the predefined
_PER_EPOCH_METRICS after each epoch.
"""
def __init__(self, metric_logger=None):
self._logger = metric_logger or logger.BaseBenchmarkLogger()
self._per_batch_metrics = _PER_BATCH_METRICS
self._per_epoch_metrics = _PER_EPOCH_METRICS
self._global_step = 0 # Initialize it in __init__
super(LoggingMetricCallback, self).__init__()
def on_batch_end(self, batch, logs=None):
"""Log metrics after each batch."""
self._global_step += 1
for metric in _PER_BATCH_METRICS:
self._logger.log_metric(
_PER_BATCH_METRICS[metric],
logs.get(metric),
global_step=self._global_step)
def on_epoch_end(self, epoch, logs=None):
"""Log metrics after each epoch."""
for metric in _PER_EPOCH_METRICS:
self._logger.log_metric(
_PER_EPOCH_METRICS[metric],
logs.get(metric),
global_step=self._global_step)
def get_model_callbacks(name_list, **kwargs):
"""Factory for getting a list of TensorFlow hooks for training by name.
Args:
name_list: a list of strings to name desired callback classes. Allowed:
ExamplesPerSecondCallback, LoggingMetricCallback, which are defined
as keys in CALLBACKS.
**kwargs: a dictionary of arguments to the callbacks.
Returns:
    list of instantiated callbacks, ready to be passed to model.fit.
Raises:
ValueError: if an unrecognized name is passed.
"""
if not name_list:
return []
callbacks = []
for name in name_list:
callback_name = CALLBACKS.get(name.strip().lower())
if callback_name is None:
raise ValueError(
"Unrecognized training callback requested: {}".format(name))
else:
callbacks.append(callback_name(**kwargs))
return callbacks
def get_examples_per_second_callback(
every_n_steps=1, batch_size=32, metric_logger=None, **kwargs): # pylint: disable=unused-argument
"""Function to get ExamplesPerSecondCallback."""
return ExamplesPerSecondCallback(
batch_size=batch_size, every_n_steps=every_n_steps,
metric_logger=metric_logger or logger.get_benchmark_logger())
def get_logging_metric_callback(metric_logger=None, **kwargs): # pylint: disable=unused-argument
"""Function to get LoggingMetricCallback."""
return LoggingMetricCallback(
metric_logger=metric_logger or logger.get_benchmark_logger())
# A dictionary to map the callback name and its corresponding function
CALLBACKS = {
"examplespersecondcallback": get_examples_per_second_callback,
"loggingmetriccallback": get_logging_metric_callback,
}
# MNIST in TensorFlow
This directory builds a convolutional neural net to classify the [MNIST
dataset](http://yann.lecun.com/exdb/mnist/) using the
[tf.data](https://www.tensorflow.org/api_docs/python/tf/data),
[tf.estimator.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator),
and
[tf.layers](https://www.tensorflow.org/api_docs/python/tf/layers)
APIs.
## Setup
To begin, you'll simply need the latest version of TensorFlow installed.
First make sure you've [added the models folder to your Python path](/official/#running-the-models); otherwise you may encounter an error like `ImportError: No module named official.mnist`.
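For example, on Linux or macOS (with `/path/to/models` standing in for the directory you cloned the repository into):
```
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```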
Then to train the model, run the following:
```
python mnist.py
```
The model will begin training and will automatically evaluate itself on the
validation data.
Illustrative unit tests and benchmarks can be run with:
```
python mnist_test.py
python mnist_test.py --benchmarks=.
```
## Exporting the model
You can export the model into TensorFlow [SavedModel](https://www.tensorflow.org/guide/saved_model) format by using the argument `--export_dir`:
```
python mnist.py --export_dir /tmp/mnist_saved_model
```
The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_model/` (e.g. `/tmp/mnist_saved_model/1513630966/`).
**Getting predictions with SavedModel**
Use [`saved_model_cli`](https://www.tensorflow.org/guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
```
saved_model_cli run --dir /tmp/mnist_saved_model/TIMESTAMP --tag_set serve --signature_def classify --inputs image=examples.npy
```
`examples.npy` contains the data from `example5.png` and `example3.png` in a numpy array, in that order. The array values are normalized to the range [0, 1].
The output should look similar to the following:
```
Result for output key classes:
[5 3]
Result for output key probabilities:
[[ 1.53558474e-07 1.95694142e-13 1.31193523e-09 5.47467265e-03
5.85711526e-22 9.94520664e-01 3.48423509e-06 2.65365645e-17
9.78631419e-07 3.15522470e-08]
[ 1.22413359e-04 5.87615965e-08 1.72251271e-06 9.39960718e-01
3.30306928e-11 2.87386645e-02 2.82353517e-02 8.21146413e-18
2.52568233e-03 4.15460236e-04]]
```
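If you need to rebuild such an input file, here is a minimal sketch (assuming 28x28 grayscale PNGs named `example5.png` and `example3.png` exist in the working directory, and that numpy and Pillow are installed):
```
import numpy as np
from PIL import Image

# Load each digit image as a 28x28 grayscale array, scaled to [0, 1].
images = [np.asarray(Image.open(name).convert("L"), dtype=np.float32) / 255.0
          for name in ("example5.png", "example3.png")]
np.save("examples.npy", np.stack(images))  # shape: (2, 28, 28)
```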
## Experimental: Eager Execution
[Eager execution](https://research.googleblog.com/2017/10/eager-execution-imperative-define-by.html)
(a preview feature in TensorFlow 1.5) is an imperative interface to TensorFlow.
The exact same model defined in `mnist.py` can be trained without creating a
TensorFlow graph using:
```
python mnist_eager.py
```
## Experimental: TPU Acceleration
`mnist.py` (and `mnist_eager.py`) demonstrate training a neural network to
classify digits on CPUs and GPUs. `mnist_tpu.py` can be used to train the
same model using TPUs for hardware acceleration. More information is
available in the [tensorflow/tpu](https://github.com/tensorflow/tpu) repository.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""tf.data.Dataset interface to the MNIST dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import shutil
import tempfile
import numpy as np
from six.moves import urllib
import tensorflow as tf
def read32(bytestream):
"""Read 4 bytes from bytestream as an unsigned 32-bit integer."""
dt = np.dtype(np.uint32).newbyteorder('>')
return np.frombuffer(bytestream.read(4), dtype=dt)[0]
def check_image_file_header(filename):
"""Validate that filename corresponds to images for the MNIST dataset."""
with tf.gfile.Open(filename, 'rb') as f:
magic = read32(f)
read32(f) # num_images, unused
rows = read32(f)
cols = read32(f)
if magic != 2051:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
if rows != 28 or cols != 28:
raise ValueError(
'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
(f.name, rows, cols))
def check_labels_file_header(filename):
"""Validate that filename corresponds to labels for the MNIST dataset."""
with tf.gfile.Open(filename, 'rb') as f:
magic = read32(f)
read32(f) # num_items, unused
if magic != 2049:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
def download(directory, filename):
"""Download (and unzip) a file from the MNIST dataset if not already done."""
filepath = os.path.join(directory, filename)
if tf.gfile.Exists(filepath):
return filepath
if not tf.gfile.Exists(directory):
tf.gfile.MakeDirs(directory)
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
_, zipped_filepath = tempfile.mkstemp(suffix='.gz')
print('Downloading %s to %s' % (url, zipped_filepath))
urllib.request.urlretrieve(url, zipped_filepath)
with gzip.open(zipped_filepath, 'rb') as f_in, \
tf.gfile.Open(filepath, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
os.remove(zipped_filepath)
return filepath
def dataset(directory, images_file, labels_file):
"""Download and parse MNIST dataset."""
images_file = download(directory, images_file)
labels_file = download(directory, labels_file)
check_image_file_header(images_file)
check_labels_file_header(labels_file)
def decode_image(image):
# Normalize from [0, 255] to [0.0, 1.0]
image = tf.decode_raw(image, tf.uint8)
image = tf.cast(image, tf.float32)
image = tf.reshape(image, [784])
return image / 255.0
def decode_label(label):
label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
label = tf.reshape(label, []) # label is a scalar
return tf.to_int32(label)
images = tf.data.FixedLengthRecordDataset(
images_file, 28 * 28, header_bytes=16).map(decode_image)
labels = tf.data.FixedLengthRecordDataset(
labels_file, 1, header_bytes=8).map(decode_label)
return tf.data.Dataset.zip((images, labels))
def train(directory):
"""tf.data.Dataset object for MNIST training data."""
return dataset(directory, 'train-images-idx3-ubyte',
'train-labels-idx1-ubyte')
def test(directory):
"""tf.data.Dataset object for MNIST test data."""
return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.mnist import dataset
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
LEARNING_RATE = 1e-4
def create_model(data_format):
"""Model to recognize digits in the MNIST dataset.
Network structure is equivalent to:
https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
and
https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
But uses the tf.keras API.
Args:
data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
typically faster on GPUs while 'channels_last' is typically faster on
CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
Returns:
A tf.keras.Model.
"""
if data_format == 'channels_first':
input_shape = [1, 28, 28]
else:
assert data_format == 'channels_last'
input_shape = [28, 28, 1]
l = tf.keras.layers
max_pool = l.MaxPooling2D(
(2, 2), (2, 2), padding='same', data_format=data_format)
# The model consists of a sequential chain of layers, so tf.keras.Sequential
# (a subclass of tf.keras.Model) makes for a compact description.
return tf.keras.Sequential(
[
l.Reshape(
target_shape=input_shape,
input_shape=(28 * 28,)),
l.Conv2D(
32,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu),
max_pool,
l.Conv2D(
64,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu),
max_pool,
l.Flatten(),
l.Dense(1024, activation=tf.nn.relu),
l.Dropout(0.4),
l.Dense(10)
])
def define_mnist_flags():
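  """Define command line flags for the MNIST model."""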
flags_core.define_base()
flags_core.define_performance(num_parallel_calls=False)
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
flags_core.set_defaults(data_dir='/tmp/mnist_data',
model_dir='/tmp/mnist_model',
batch_size=100,
train_epochs=40)
def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator."""
model = create_model(params['data_format'])
image = features
if isinstance(image, dict):
image = features['image']
if mode == tf.estimator.ModeKeys.PREDICT:
logits = model(image, training=False)
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits),
}
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
})
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
logits = model(image, training=True)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
accuracy = tf.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1))
# Name tensors to be logged with LoggingTensorHook.
tf.identity(LEARNING_RATE, 'learning_rate')
tf.identity(loss, 'cross_entropy')
tf.identity(accuracy[1], name='train_accuracy')
# Save accuracy scalar to Tensorboard output.
tf.summary.scalar('train_accuracy', accuracy[1])
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
if mode == tf.estimator.ModeKeys.EVAL:
logits = model(image, training=False)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL,
loss=loss,
eval_metric_ops={
'accuracy':
tf.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1)),
})
def run_mnist(flags_obj):
"""Run MNIST training and eval loop.
Args:
flags_obj: An object containing parsed flag values.
"""
model_helpers.apply_clean(flags_obj)
model_function = model_fn
session_config = tf.ConfigProto(
inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
allow_soft_placement=True)
distribution_strategy = distribution_utils.get_distribution_strategy(
flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy, session_config=session_config)
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
mnist_classifier = tf.estimator.Estimator(
model_fn=model_function,
model_dir=flags_obj.model_dir,
config=run_config,
params={
'data_format': data_format,
})
# Set up training and evaluation input functions.
def train_input_fn():
"""Prepare data for training."""
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch.
ds = dataset.train(flags_obj.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)
# Iterate through the dataset a set number (`epochs_between_evals`) of times
# during each training session.
ds = ds.repeat(flags_obj.epochs_between_evals)
return ds
def eval_input_fn():
return dataset.test(flags_obj.data_dir).batch(
flags_obj.batch_size).make_one_shot_iterator().get_next()
# Set up hook that outputs training logs every 100 steps.
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks, model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size)
# Train and evaluate model.
for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results)
if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
eval_results['accuracy']):
break
# Export the model
if flags_obj.export_dir is not None:
image = tf.placeholder(tf.float32, [None, 28, 28])
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'image': image,
})
mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn,
strip_default_attrs=True)
def main(_):
run_mnist(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
define_mnist_flags()
absl_app.run(main)