Commit f9ac9618 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Remove this r1 folder from the master branch in June, 2020.

PiperOrigin-RevId: 317772122
parent d4f5c193
![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
# Legacy Models
The **r1** folder contains legacy model implementations developed
using TensorFlow 1.x.
**Note: We will remove this r1 folder from the master branch in June, 2020.**
After removal, you will still be able to access legacy models
in the previous releases.
(e.g., [v2.1.0](https://github.com/tensorflow/models/releases/tag/v2.1.0))
| Model | Description | Reference |
| ----- | ----------- | --------- |
| [Gradient Boosted Trees](boosted_trees) | A gradient boosted trees model to classify higgs boson process from HIGGS dataset | [Link](https://en.wikipedia.org/wiki/Gradient_boosting) |
| [MNIST](mnist) | A basic model to classify digits from the MNIST dataset | [Link](http://yann.lecun.com/exdb/mnist/) |
| [NCF](ncf) | NCF Estimator implementation | [arXiv:1708.05031](https://arxiv.org/abs/1708.05031) |
| [ResNet](resnet) | A deep residual network for image recognition | [arXiv:1512.03385](https://arxiv.org/abs/1512.03385) |
| [Transformer](transformer) | A transformer model to translate the WMT English to German dataset | [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) |
| [Wide & Deep Learning](wide_deep) | A model that combines a wide linear model and deep neural network for recommender systems | [arXiv:1606.07792](https://arxiv.org/abs/1606.07792) |
![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
# Classifying Higgs boson processes in the HIGGS Data Set
## Overview
The [HIGGS Data Set](https://archive.ics.uci.edu/ml/datasets/HIGGS) contains 11 million samples with 28 features, and is for the classification problem to distinguish between a signal process which produces Higgs bosons and a background process which does not.
We use Gradient Boosted Trees algorithm to distinguish the two classes.
---
The code sample uses the high level `tf.estimator.Estimator` and `tf.data.Dataset`. These APIs are great for fast iteration and quickly adapting models to your own datasets without major code overhauls. It allows you to move from single-worker training to distributed training, and makes it easy to export model binaries for prediction. Here, for further simplicity and faster execution, we use a utility function `tf.contrib.estimator.boosted_trees_classifier_train_in_memory`. This utility function is especially effective when the input is provided as in-memory data sets like numpy arrays.
An input function for the `Estimator` typically uses `tf.data.Dataset` API, which can handle various data control like streaming, batching, transform and shuffling. However `boosted_trees_classifier_train_in_memory()` utility function requires that the entire data is provided as a single batch (i.e. without using `batch()` API). Thus in this practice, simply `Dataset.from_tensors()` is used to convert numpy arrays into structured tensors, and `Dataset.zip()` is used to put features and label together.
For further references of `Dataset`, [Read more here](https://www.tensorflow.org/guide/datasets).
## Running the code
First make sure you've [added the models folder to your Python path](/official/#running-the-models); otherwise you may encounter an error like `ImportError: No module named official.boosted_trees`.
### Setup
The [HIGGS Data Set](https://archive.ics.uci.edu/ml/datasets/HIGGS) that this sample uses for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). We have provided a script that downloads and cleans the necessary files.
```
python data_download.py
```
This will download a file and store the processed file under the directory designated by `--data_dir` (defaults to `/tmp/higgs_data/`). To change the target directory, set the `--data_dir` flag. The directory could be network storages that Tensorflow supports (like Google Cloud Storage, `gs://<bucket>/<path>/`).
The file downloaded to the local temporary folder is about 2.8 GB, and the processed file is about 0.8 GB, so there should be enough storage to handle them.
### Training
This example uses about 3 GB of RAM during training.
You can run the code locally as follows:
```
python train_higgs.py
```
The model is by default saved to `/tmp/higgs_model`, which can be changed using the `--model_dir` flag.
Note that the model_dir is cleaned up before every time training starts.
Model parameters can be adjusted by flags, like `--n_trees`, `--max_depth`, `--learning_rate` and so on. Check out the code for details.
The final accuracy will be around 74% and loss will be around 0.516 over the eval set, when trained with the default parameters.
By default, the first 1 million examples among 11 millions are used for training, and the last 1 million examples are used for evaluation.
The training/evaluation data can be selected as index ranges by flags `--train_start`, `--train_count`, `--eval_start`, `--eval_count`, etc.
### TensorBoard
Run TensorBoard to inspect the details about the graph and training progression.
```
tensorboard --logdir=/tmp/higgs_model # set logdir as --model_dir set during training.
```
## Inference with SavedModel
You can export the model into Tensorflow [SavedModel](https://www.tensorflow.org/guide/saved_model) format by using the argument `--export_dir`:
```
python train_higgs.py --export_dir /tmp/higgs_boosted_trees_saved_model
```
After the model finishes training, use [`saved_model_cli`](https://www.tensorflow.org/guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
Try the following commands to inspect the SavedModel:
**Replace `${TIMESTAMP}` with the folder produced (e.g. 1524249124)**
```
# List possible tag_sets. Only one metagraph is saved, so there will be one option.
saved_model_cli show --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/
# Show SignatureDefs for tag_set=serve. SignatureDefs define the outputs to show.
saved_model_cli show --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/ \
--tag_set serve --all
```
### Inference
Let's use the model to predict the income group of two examples.
Note that this model exports SavedModel with the custom parsing module that accepts csv lines as features. (Each line is an example with 28 columns; be careful to not add a label column, unlike in the training data.)
```
saved_model_cli run --dir /tmp/boosted_trees_higgs_saved_model/${TIMESTAMP}/ \
--tag_set serve --signature_def="predict" \
--input_exprs='inputs=["0.869293,-0.635082,0.225690,0.327470,-0.689993,0.754202,-0.248573,-1.092064,0.0,1.374992,-0.653674,0.930349,1.107436,1.138904,-1.578198,-1.046985,0.0,0.657930,-0.010455,-0.045767,3.101961,1.353760,0.979563,0.978076,0.920005,0.721657,0.988751,0.876678", "1.595839,-0.607811,0.007075,1.818450,-0.111906,0.847550,-0.566437,1.581239,2.173076,0.755421,0.643110,1.426367,0.0,0.921661,-1.190432,-1.615589,0.0,0.651114,-0.654227,-1.274345,3.101961,0.823761,0.938191,0.971758,0.789176,0.430553,0.961357,0.957818"]'
```
This will print out the predicted classes and class probabilities. Something like:
```
Result for output key class_ids:
[[1]
[0]]
Result for output key classes:
[['1']
['0']]
Result for output key logistic:
[[0.6440273 ]
[0.10902369]]
Result for output key logits:
[[ 0.59288704]
[-2.1007526 ]]
Result for output key probabilities:
[[0.3559727 0.6440273]
[0.8909763 0.1090237]]
```
Please note that "predict" signature_def gives out different (more detailed) results than "classification" or "serving_default".
## Additional Links
If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).
You can also [train models on Cloud ML Engine](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction), which provides [hyperparameter tuning](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#hyperparameter_tuning) to maximize your model's results and enables [deploying your model for prediction](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#deploy_a_model_to_support_prediction).
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Downloads the UCI HIGGS Dataset and prepares train data.
The details on the dataset are in https://archive.ics.uci.edu/ml/datasets/HIGGS
It takes a while as it needs to download 2.8 GB over the network, process, then
store it into the specified location as a compressed numpy file.
Usage:
$ python data_download.py --data_dir=/tmp/higgs_data
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tempfile
# pylint: disable=g-bad-import-order
import numpy as np
import pandas as pd
from six.moves import urllib
from absl import app as absl_app
from absl import flags
import tensorflow as tf
from official.utils.flags import core as flags_core
URL_ROOT = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280"
INPUT_FILE = "HIGGS.csv.gz"
NPZ_FILE = "HIGGS.csv.gz.npz" # numpy compressed file to contain "data" array.
def _download_higgs_data_and_save_npz(data_dir):
"""Download higgs data and store as a numpy compressed file."""
input_url = URL_ROOT + "/" + INPUT_FILE
np_filename = os.path.join(data_dir, NPZ_FILE)
if tf.gfile.Exists(np_filename):
raise ValueError("data_dir already has the processed data file: {}".format(
np_filename))
if not tf.gfile.Exists(data_dir):
tf.gfile.MkDir(data_dir)
# 2.8 GB to download.
try:
tf.logging.info("Data downloading...")
temp_filename, _ = urllib.request.urlretrieve(input_url)
# Reading and parsing 11 million csv lines takes 2~3 minutes.
tf.logging.info("Data processing... taking multiple minutes...")
with gzip.open(temp_filename, "rb") as csv_file:
data = pd.read_csv(
csv_file,
dtype=np.float32,
names=["c%02d" % i for i in range(29)] # label + 28 features.
).as_matrix()
finally:
tf.gfile.Remove(temp_filename)
# Writing to temporary location then copy to the data_dir (0.8 GB).
f = tempfile.NamedTemporaryFile()
np.savez_compressed(f, data=data)
tf.gfile.Copy(f.name, np_filename)
tf.logging.info("Data saved to: {}".format(np_filename))
def main(unused_argv):
if not tf.gfile.Exists(FLAGS.data_dir):
tf.gfile.MkDir(FLAGS.data_dir)
_download_higgs_data_and_save_npz(FLAGS.data_dir)
def define_data_download_flags():
"""Add flags specifying data download arguments."""
flags.DEFINE_string(
name="data_dir", default="/tmp/higgs_data",
help=flags_core.help_wrap(
"Directory to download higgs dataset and store training/eval data."))
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
define_data_download_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A script that builds boosted trees over higgs data.
If you haven't, please run data_download.py beforehand to prepare the data.
For some more details on this example, please refer to README.md as well.
Note that the model_dir is cleaned up before starting the training.
Usage:
$ python train_higgs.py --n_trees=100 --max_depth=6 --learning_rate=0.1 \
--model_dir=/tmp/higgs_model
Note that BoostedTreesClassifier is available since Tensorflow 1.8.0.
So you need to install recent enough version of Tensorflow to use this example.
The training data is by default the first million examples out of 11M examples,
and eval data is by default the last million examples.
They are controlled by --train_start, --train_count, --eval_start, --eval_count.
e.g. to train over the first 10 million examples instead of 1 million:
$ python train_higgs.py --n_trees=100 --max_depth=6 --learning_rate=0.1 \
--model_dir=/tmp/higgs_model --train_count=10000000
Training history and metrics can be inspected using tensorboard.
Set --logdir as the --model_dir set by flag when training
(or the default /tmp/higgs_model).
$ tensorboard --logdir=/tmp/higgs_model
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app as absl_app
from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
from official.r1.utils.logs import logger
from official.utils.flags import core as flags_core
from official.utils.flags._conventions import help_wrap
NPZ_FILE = "HIGGS.csv.gz.npz" # numpy compressed file containing "data" array
def read_higgs_data(data_dir, train_start, train_count, eval_start, eval_count):
"""Reads higgs data from csv and returns train and eval data.
Args:
data_dir: A string, the directory of higgs dataset.
train_start: An integer, the start index of train examples within the data.
train_count: An integer, the number of train examples within the data.
eval_start: An integer, the start index of eval examples within the data.
eval_count: An integer, the number of eval examples within the data.
Returns:
Numpy array of train data and eval data.
"""
npz_filename = os.path.join(data_dir, NPZ_FILE)
try:
# gfile allows numpy to read data from network data sources as well.
with tf.gfile.Open(npz_filename, "rb") as npz_file:
with np.load(npz_file) as npz:
data = npz["data"]
except tf.errors.NotFoundError as e:
raise RuntimeError(
"Error loading data; use data_download.py to prepare the data.\n{}: {}"
.format(type(e).__name__, e))
return (data[train_start:train_start+train_count],
data[eval_start:eval_start+eval_count])
# This showcases how to make input_fn when the input data is available in the
# form of numpy arrays.
def make_inputs_from_np_arrays(features_np, label_np):
"""Makes and returns input_fn and feature_columns from numpy arrays.
The generated input_fn will return tf.data.Dataset of feature dictionary and a
label, and feature_columns will consist of the list of
tf.feature_column.BucketizedColumn.
Note, for in-memory training, tf.data.Dataset should contain the whole data
as a single tensor. Don't use batch.
Args:
features_np: A numpy ndarray (shape=[batch_size, num_features]) for
float32 features.
label_np: A numpy ndarray (shape=[batch_size, 1]) for labels.
Returns:
input_fn: A function returning a Dataset of feature dict and label.
feature_names: A list of feature names.
feature_column: A list of tf.feature_column.BucketizedColumn.
"""
num_features = features_np.shape[1]
features_np_list = np.split(features_np, num_features, axis=1)
# 1-based feature names.
feature_names = ["feature_%02d" % (i + 1) for i in range(num_features)]
# Create source feature_columns and bucketized_columns.
def get_bucket_boundaries(feature):
"""Returns bucket boundaries for feature by percentiles."""
return np.unique(np.percentile(feature, range(0, 100))).tolist()
source_columns = [
tf.feature_column.numeric_column(
feature_name, dtype=tf.float32,
# Although higgs data have no missing values, in general, default
# could be set as 0 or some reasonable value for missing values.
default_value=0.0)
for feature_name in feature_names
]
bucketized_columns = [
tf.feature_column.bucketized_column(
source_columns[i],
boundaries=get_bucket_boundaries(features_np_list[i]))
for i in range(num_features)
]
# Make an input_fn that extracts source features.
def input_fn():
"""Returns features as a dictionary of numpy arrays, and a label."""
features = {
feature_name: tf.constant(features_np_list[i])
for i, feature_name in enumerate(feature_names)
}
return tf.data.Dataset.zip((tf.data.Dataset.from_tensors(features),
tf.data.Dataset.from_tensors(label_np),))
return input_fn, feature_names, bucketized_columns
def make_eval_inputs_from_np_arrays(features_np, label_np):
"""Makes eval input as streaming batches."""
num_features = features_np.shape[1]
features_np_list = np.split(features_np, num_features, axis=1)
# 1-based feature names.
feature_names = ["feature_%02d" % (i + 1) for i in range(num_features)]
def input_fn():
features = {
feature_name: tf.constant(features_np_list[i])
for i, feature_name in enumerate(feature_names)
}
return tf.data.Dataset.zip((
tf.data.Dataset.from_tensor_slices(features),
tf.data.Dataset.from_tensor_slices(label_np),)).batch(1000)
return input_fn
def _make_csv_serving_input_receiver_fn(column_names, column_defaults):
"""Returns serving_input_receiver_fn for csv.
The input arguments are relevant to `tf.decode_csv()`.
Args:
column_names: a list of column names in the order within input csv.
column_defaults: a list of default values with the same size of
column_names. Each entity must be either a list of one scalar, or an
empty list to denote the corresponding column is required.
e.g. [[""], [2.5], []] indicates the third column is required while
the first column must be string and the second must be float/double.
Returns:
a serving_input_receiver_fn that handles csv for serving.
"""
def serving_input_receiver_fn():
csv = tf.placeholder(dtype=tf.string, shape=[None], name="csv")
features = dict(zip(column_names, tf.decode_csv(csv, column_defaults)))
receiver_tensors = {"inputs": csv}
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
return serving_input_receiver_fn
def train_boosted_trees(flags_obj):
"""Train boosted_trees estimator on HIGGS data.
Args:
flags_obj: An object containing parsed flag values.
"""
# Clean up the model directory if present.
if tf.gfile.Exists(flags_obj.model_dir):
tf.gfile.DeleteRecursively(flags_obj.model_dir)
tf.logging.info("## Data loading...")
train_data, eval_data = read_higgs_data(
flags_obj.data_dir, flags_obj.train_start, flags_obj.train_count,
flags_obj.eval_start, flags_obj.eval_count)
tf.logging.info("## Data loaded; train: {}{}, eval: {}{}".format(
train_data.dtype, train_data.shape, eval_data.dtype, eval_data.shape))
# Data consists of one label column followed by 28 feature columns.
train_input_fn, feature_names, feature_columns = make_inputs_from_np_arrays(
features_np=train_data[:, 1:], label_np=train_data[:, 0:1])
eval_input_fn = make_eval_inputs_from_np_arrays(
features_np=eval_data[:, 1:], label_np=eval_data[:, 0:1])
tf.logging.info("## Features prepared. Training starts...")
# Create benchmark logger to log info about the training and metric values
run_params = {
"train_start": flags_obj.train_start,
"train_count": flags_obj.train_count,
"eval_start": flags_obj.eval_start,
"eval_count": flags_obj.eval_count,
"n_trees": flags_obj.n_trees,
"max_depth": flags_obj.max_depth,
}
benchmark_logger = logger.config_benchmark_logger(flags_obj)
benchmark_logger.log_run_info(
model_name="boosted_trees",
dataset_name="higgs",
run_params=run_params,
test_id=flags_obj.benchmark_test_id)
# Though BoostedTreesClassifier is under tf.estimator, faster in-memory
# training is yet provided as a contrib library.
from tensorflow.contrib import estimator as contrib_estimator # pylint: disable=g-import-not-at-top
classifier = contrib_estimator.boosted_trees_classifier_train_in_memory(
train_input_fn,
feature_columns,
model_dir=flags_obj.model_dir or None,
n_trees=flags_obj.n_trees,
max_depth=flags_obj.max_depth,
learning_rate=flags_obj.learning_rate)
# Evaluation.
eval_results = classifier.evaluate(eval_input_fn)
# Benchmark the evaluation results
benchmark_logger.log_evaluation_result(eval_results)
# Exporting the savedmodel with csv parsing.
if flags_obj.export_dir is not None:
classifier.export_savedmodel(
flags_obj.export_dir,
_make_csv_serving_input_receiver_fn(
column_names=feature_names,
# columns are all floats.
column_defaults=[[0.0]] * len(feature_names)),
strip_default_attrs=True)
def main(_):
train_boosted_trees(flags.FLAGS)
def define_train_higgs_flags():
"""Add tree related flags as well as training/eval configuration."""
flags_core.define_base(clean=False, stop_threshold=False, batch_size=False,
num_gpu=False, export_dir=True)
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_integer(
name="train_start", default=0,
help=help_wrap("Start index of train examples within the data."))
flags.DEFINE_integer(
name="train_count", default=1000000,
help=help_wrap("Number of train examples within the data."))
flags.DEFINE_integer(
name="eval_start", default=10000000,
help=help_wrap("Start index of eval examples within the data."))
flags.DEFINE_integer(
name="eval_count", default=1000000,
help=help_wrap("Number of eval examples within the data."))
flags.DEFINE_integer(
"n_trees", default=100, help=help_wrap("Number of trees to build."))
flags.DEFINE_integer(
"max_depth", default=6, help=help_wrap("Maximum depths of each tree."))
flags.DEFINE_float(
"learning_rate", default=0.1,
help=help_wrap("The learning rate."))
flags_core.set_defaults(data_dir="/tmp/higgs_data",
model_dir="/tmp/higgs_model")
if __name__ == "__main__":
# Training progress and eval results are shown as logging.INFO; so enables it.
tf.logging.set_verbosity(tf.logging.INFO)
define_train_higgs_flags()
absl_app.run(main)
![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
# MNIST in TensorFlow
This directory builds a convolutional neural net to classify the [MNIST
dataset](http://yann.lecun.com/exdb/mnist/) using the
[tf.data](https://www.tensorflow.org/api_docs/python/tf/data),
[tf.estimator.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator),
and
[tf.layers](https://www.tensorflow.org/api_docs/python/tf/layers)
APIs.
## Setup
To begin, you'll simply need the latest version of TensorFlow installed.
First make sure you've [added the models folder to your Python path]:
```shell
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```
Otherwise you may encounter an error like `ImportError: No module named official.mnist`.
Then to train the model, run the following:
```
python mnist.py
```
The model will begin training and will automatically evaluate itself on the
validation data.
Illustrative unit tests and benchmarks can be run with:
```
python mnist_test.py
python mnist_test.py --benchmarks=.
```
## Exporting the model
You can export the model into Tensorflow [SavedModel](https://www.tensorflow.org/guide/saved_model) format by using the argument `--export_dir`:
```
python mnist.py --export_dir /tmp/mnist_saved_model
```
The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_model/` (e.g. `/tmp/mnist_saved_model/1513630966/`).
**Getting predictions with SavedModel**
Use [`saved_model_cli`](https://www.tensorflow.org/guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
```
saved_model_cli run --dir /tmp/mnist_saved_model/TIMESTAMP --tag_set serve --signature_def classify --inputs image=examples.npy
```
`examples.npy` contains the data from `example5.png` and `example3.png` in a numpy array, in that order. The array values are normalized to values between 0 and 1.
The output should look similar to below:
```
Result for output key classes:
[5 3]
Result for output key probabilities:
[[ 1.53558474e-07 1.95694142e-13 1.31193523e-09 5.47467265e-03
5.85711526e-22 9.94520664e-01 3.48423509e-06 2.65365645e-17
9.78631419e-07 3.15522470e-08]
[ 1.22413359e-04 5.87615965e-08 1.72251271e-06 9.39960718e-01
3.30306928e-11 2.87386645e-02 2.82353517e-02 8.21146413e-18
2.52568233e-03 4.15460236e-04]]
```
## Experimental: Eager Execution
[Eager execution](https://research.googleblog.com/2017/10/eager-execution-imperative-define-by.html)
(an preview feature in TensorFlow 1.5) is an imperative interface to TensorFlow.
The exact same model defined in `mnist.py` can be trained without creating a
TensorFlow graph using:
```
python mnist_eager.py
```
## Experimental: TPU Acceleration
`mnist.py` (and `mnist_eager.py`) demonstrate training a neural network to
classify digits on CPUs and GPUs. `mnist_tpu.py` can be used to train the
same model using TPUs for hardware acceleration. More information in
the [tensorflow/tpu](https://github.com/tensorflow/tpu) repository.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""tf.data.Dataset interface to the MNIST dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import shutil
import tempfile
import numpy as np
from six.moves import urllib
import tensorflow as tf
def read32(bytestream):
"""Read 4 bytes from bytestream as an unsigned 32-bit integer."""
dt = np.dtype(np.uint32).newbyteorder('>')
return np.frombuffer(bytestream.read(4), dtype=dt)[0]
def check_image_file_header(filename):
"""Validate that filename corresponds to images for the MNIST dataset."""
with tf.io.gfile.GFile(filename, 'rb') as f:
magic = read32(f)
read32(f) # num_images, unused
rows = read32(f)
cols = read32(f)
if magic != 2051:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
if rows != 28 or cols != 28:
raise ValueError(
'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
(f.name, rows, cols))
def check_labels_file_header(filename):
"""Validate that filename corresponds to labels for the MNIST dataset."""
with tf.io.gfile.GFile(filename, 'rb') as f:
magic = read32(f)
read32(f) # num_items, unused
if magic != 2049:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
def download(directory, filename):
"""Download (and unzip) a file from the MNIST dataset if not already done."""
filepath = os.path.join(directory, filename)
if tf.io.gfile.exists(filepath):
return filepath
if not tf.io.gfile.exists(directory):
tf.io.gfile.makedirs(directory)
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
_, zipped_filepath = tempfile.mkstemp(suffix='.gz')
print('Downloading %s to %s' % (url, zipped_filepath))
urllib.request.urlretrieve(url, zipped_filepath)
with gzip.open(zipped_filepath, 'rb') as f_in, \
tf.io.gfile.GFile(filepath, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
os.remove(zipped_filepath)
return filepath
def dataset(directory, images_file, labels_file):
"""Download and parse MNIST dataset."""
images_file = download(directory, images_file)
labels_file = download(directory, labels_file)
check_image_file_header(images_file)
check_labels_file_header(labels_file)
def decode_image(image):
# Normalize from [0, 255] to [0.0, 1.0]
image = tf.io.decode_raw(image, tf.uint8)
image = tf.cast(image, tf.float32)
image = tf.reshape(image, [784])
return image / 255.0
def decode_label(label):
label = tf.io.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
label = tf.reshape(label, []) # label is a scalar
return tf.cast(label, tf.int32)
images = tf.data.FixedLengthRecordDataset(
images_file, 28 * 28, header_bytes=16).map(decode_image)
labels = tf.data.FixedLengthRecordDataset(
labels_file, 1, header_bytes=8).map(decode_label)
return tf.data.Dataset.zip((images, labels))
def train(directory):
"""tf.data.Dataset object for MNIST training data."""
return dataset(directory, 'train-images-idx3-ubyte',
'train-labels-idx1-ubyte')
def test(directory):
"""tf.data.Dataset object for MNIST test data."""
return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app as absl_app
from absl import flags
from absl import logging
from six.moves import range
import tensorflow.compat.v1 as tf
from official.r1.mnist import dataset
from official.r1.utils.logs import hooks_helper
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
LEARNING_RATE = 1e-4
def create_model(data_format):
"""Model to recognize digits in the MNIST dataset.
Network structure is equivalent to:
https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
But uses the tf.keras API.
Args:
data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
typically faster on GPUs while 'channels_last' is typically faster on
CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
Returns:
A tf.keras.Model.
"""
if data_format == 'channels_first':
input_shape = [1, 28, 28]
else:
assert data_format == 'channels_last'
input_shape = [28, 28, 1]
l = tf.keras.layers
max_pool = l.MaxPooling2D(
(2, 2), (2, 2), padding='same', data_format=data_format)
# The model consists of a sequential chain of layers, so tf.keras.Sequential
# (a subclass of tf.keras.Model) makes for a compact description.
return tf.keras.Sequential(
[
l.Reshape(
target_shape=input_shape,
input_shape=(28 * 28,)),
l.Conv2D(
32,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu),
max_pool,
l.Conv2D(
64,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu),
max_pool,
l.Flatten(),
l.Dense(1024, activation=tf.nn.relu),
l.Dropout(0.4),
l.Dense(10)
])
def define_mnist_flags():
"""Defines flags for mnist."""
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True, stop_threshold=True,
num_gpu=True, hooks=True, export_dir=True,
distribution_strategy=True)
flags_core.define_performance(inter_op=True, intra_op=True,
num_parallel_calls=False,
all_reduce_alg=True)
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
flags_core.set_defaults(data_dir='/tmp/mnist_data',
model_dir='/tmp/mnist_model',
batch_size=100,
train_epochs=40)
def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator."""
model = create_model(params['data_format'])
image = features
if isinstance(image, dict):
image = features['image']
if mode == tf.estimator.ModeKeys.PREDICT:
logits = model(image, training=False)
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits),
}
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
})
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=LEARNING_RATE)
logits = model(image, training=True)
loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels=labels,
logits=logits)
accuracy = tf.compat.v1.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1))
# Name tensors to be logged with LoggingTensorHook.
tf.identity(LEARNING_RATE, 'learning_rate')
tf.identity(loss, 'cross_entropy')
tf.identity(accuracy[1], name='train_accuracy')
# Save accuracy scalar to Tensorboard output.
tf.summary.scalar('train_accuracy', accuracy[1])
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=loss,
train_op=optimizer.minimize(
loss,
tf.compat.v1.train.get_or_create_global_step()))
if mode == tf.estimator.ModeKeys.EVAL:
logits = model(image, training=False)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL,
loss=loss,
eval_metric_ops={
'accuracy':
tf.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1)),
})
def run_mnist(flags_obj):
"""Run MNIST training and eval loop.
Args:
flags_obj: An object containing parsed flag values.
"""
model_helpers.apply_clean(flags_obj)
model_function = model_fn
session_config = tf.compat.v1.ConfigProto(
inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
allow_soft_placement=True)
distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_core.get_num_gpus(flags_obj),
all_reduce_alg=flags_obj.all_reduce_alg)
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy, session_config=session_config)
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
else 'channels_last')
mnist_classifier = tf.estimator.Estimator(
model_fn=model_function,
model_dir=flags_obj.model_dir,
config=run_config,
params={
'data_format': data_format,
})
# Set up training and evaluation input functions.
def train_input_fn():
"""Prepare data for training."""
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch.
ds = dataset.train(flags_obj.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)
# Iterate through the dataset a set number (`epochs_between_evals`) of times
# during each training session.
ds = ds.repeat(flags_obj.epochs_between_evals)
return ds
def eval_input_fn():
return dataset.test(flags_obj.data_dir).batch(
flags_obj.batch_size).make_one_shot_iterator().get_next()
# Set up hook that outputs training logs every 100 steps.
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks, model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size)
# Train and evaluate model.
for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results)
if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
eval_results['accuracy']):
break
# Export the model
if flags_obj.export_dir is not None:
image = tf.compat.v1.placeholder(tf.float32, [None, 28, 28])
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'image': image,
})
mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn,
strip_default_attrs=True)
def main(_):
run_mnist(flags.FLAGS)
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
define_mnist_flags()
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MNIST model training with TensorFlow eager execution.
See:
https://research.googleblog.com/2017/10/eager-execution-imperative-define-by.html
This program demonstrates training of the convolutional neural network model
defined in mnist.py with eager execution enabled.
If you are not interested in eager execution, you should ignore this file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
# pylint: disable=g-bad-import-order
from absl import app as absl_app
from absl import flags
from six.moves import range
from six.moves import zip
import tensorflow as tf
from tensorflow.python import eager as tfe
# pylint: enable=g-bad-import-order
from official.r1.mnist import dataset as mnist_dataset
from official.r1.mnist import mnist
from official.utils.flags import core as flags_core
from official.utils.misc import model_helpers
def loss(logits, labels):
return tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=labels))
def compute_accuracy(logits, labels):
predictions = tf.argmax(logits, axis=1, output_type=tf.int64)
labels = tf.cast(labels, tf.int64)
batch_size = int(logits.shape[0])
return tf.reduce_sum(
tf.cast(tf.equal(predictions, labels), dtype=tf.float32)) / batch_size
def train(model, optimizer, dataset, step_counter, log_interval=None):
"""Trains model on `dataset` using `optimizer`."""
from tensorflow.contrib import summary as contrib_summary # pylint: disable=g-import-not-at-top
start = time.time()
for (batch, (images, labels)) in enumerate(dataset):
with contrib_summary.record_summaries_every_n_global_steps(
10, global_step=step_counter):
# Record the operations used to compute the loss given the input,
# so that the gradient of the loss with respect to the variables
# can be computed.
with tf.GradientTape() as tape:
logits = model(images, training=True)
loss_value = loss(logits, labels)
contrib_summary.scalar('loss', loss_value)
contrib_summary.scalar('accuracy',
compute_accuracy(logits, labels))
grads = tape.gradient(loss_value, model.variables)
optimizer.apply_gradients(
list(zip(grads, model.variables)), global_step=step_counter)
if log_interval and batch % log_interval == 0:
rate = log_interval / (time.time() - start)
print('Step #%d\tLoss: %.6f (%d steps/sec)' % (batch, loss_value, rate))
start = time.time()
def test(model, dataset):
"""Perform an evaluation of `model` on the examples from `dataset`."""
from tensorflow.contrib import summary as contrib_summary # pylint: disable=g-import-not-at-top
avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)
accuracy = tf.keras.metrics.Accuracy('accuracy', dtype=tf.float32)
for (images, labels) in dataset:
logits = model(images, training=False)
avg_loss.update_state(loss(logits, labels))
accuracy.update_state(
tf.argmax(logits, axis=1, output_type=tf.int64),
tf.cast(labels, tf.int64))
print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' %
(avg_loss.result(), 100 * accuracy.result()))
with contrib_summary.always_record_summaries():
contrib_summary.scalar('loss', avg_loss.result())
contrib_summary.scalar('accuracy', accuracy.result())
def run_mnist_eager(flags_obj):
"""Run MNIST training and eval loop in eager mode.
Args:
flags_obj: An object containing parsed flag values.
"""
tf.enable_eager_execution()
model_helpers.apply_clean(flags.FLAGS)
# Automatically determine device and data_format
(device, data_format) = ('/gpu:0', 'channels_first')
if flags_obj.no_gpu or not tf.test.is_gpu_available():
(device, data_format) = ('/cpu:0', 'channels_last')
# If data_format is defined in FLAGS, overwrite automatically set value.
if flags_obj.data_format is not None:
data_format = flags_obj.data_format
print('Using device %s, and data format %s.' % (device, data_format))
# Load the datasets
train_ds = mnist_dataset.train(flags_obj.data_dir).shuffle(60000).batch(
flags_obj.batch_size)
test_ds = mnist_dataset.test(flags_obj.data_dir).batch(
flags_obj.batch_size)
# Create the model and optimizer
model = mnist.create_model(data_format)
optimizer = tf.train.MomentumOptimizer(flags_obj.lr, flags_obj.momentum)
# Create file writers for writing TensorBoard summaries.
if flags_obj.output_dir:
# Create directories to which summaries will be written
# tensorboard --logdir=<output_dir>
# can then be used to see the recorded summaries.
train_dir = os.path.join(flags_obj.output_dir, 'train')
test_dir = os.path.join(flags_obj.output_dir, 'eval')
tf.gfile.MakeDirs(flags_obj.output_dir)
else:
train_dir = None
test_dir = None
summary_writer = tf.compat.v2.summary.create_file_writer(
train_dir, flush_millis=10000)
test_summary_writer = tf.compat.v2.summary.create_file_writer(
test_dir, flush_millis=10000, name='test')
# Create and restore checkpoint (if one exists on the path)
checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
step_counter = tf.train.get_or_create_global_step()
checkpoint = tf.train.Checkpoint(
model=model, optimizer=optimizer, step_counter=step_counter)
# Restore variables on creation if a checkpoint exists.
checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))
# Train and evaluate for a set number of epochs.
with tf.device(device):
for _ in range(flags_obj.train_epochs):
start = time.time()
with summary_writer.as_default():
train(model, optimizer, train_ds, step_counter,
flags_obj.log_interval)
end = time.time()
print('\nTrain time for epoch #%d (%d total steps): %f' %
(checkpoint.save_counter.numpy() + 1,
step_counter.numpy(),
end - start))
with test_summary_writer.as_default():
test(model, test_ds)
checkpoint.save(checkpoint_prefix)
def define_mnist_eager_flags():
"""Defined flags and defaults for MNIST in eager mode."""
flags_core.define_base(clean=True, train_epochs=True, export_dir=True,
distribution_strategy=True)
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_integer(
name='log_interval', short_name='li', default=10,
help=flags_core.help_wrap('batches between logging training status'))
flags.DEFINE_string(
name='output_dir', short_name='od', default=None,
help=flags_core.help_wrap('Directory to write TensorBoard summaries'))
flags.DEFINE_float(name='learning_rate', short_name='lr', default=0.01,
help=flags_core.help_wrap('Learning rate.'))
flags.DEFINE_float(name='momentum', short_name='m', default=0.5,
help=flags_core.help_wrap('SGD momentum.'))
flags.DEFINE_bool(name='no_gpu', short_name='nogpu', default=False,
help=flags_core.help_wrap(
'disables GPU usage even if a GPU is available'))
flags_core.set_defaults(
data_dir='/tmp/tensorflow/mnist/input_data',
model_dir='/tmp/tensorflow/mnist/checkpoints/',
batch_size=100,
train_epochs=10,
)
def main(_):
run_mnist_eager(flags.FLAGS)
if __name__ == '__main__':
define_mnist_eager_flags()
absl_app.run(main=main)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import tensorflow.compat.v1 as tf # pylint: disable=g-bad-import-order
from absl import logging
from official.r1.mnist import mnist
BATCH_SIZE = 100
def dummy_input_fn():
image = tf.random.uniform([BATCH_SIZE, 784])
labels = tf.random.uniform([BATCH_SIZE, 1], maxval=9, dtype=tf.int32)
return image, labels
def make_estimator():
data_format = 'channels_last'
if tf.test.is_built_with_cuda():
data_format = 'channels_first'
return tf.estimator.Estimator(
model_fn=mnist.model_fn, params={
'data_format': data_format
})
class Tests(tf.test.TestCase):
"""Run tests for MNIST model.
MNIST uses contrib and will not work with TF 2.0. All tests are disabled if
using TF 2.0.
"""
def test_mnist(self):
classifier = make_estimator()
classifier.train(input_fn=dummy_input_fn, steps=2)
eval_results = classifier.evaluate(input_fn=dummy_input_fn, steps=1)
loss = eval_results['loss']
global_step = eval_results['global_step']
accuracy = eval_results['accuracy']
self.assertEqual(loss.shape, ())
self.assertEqual(2, global_step)
self.assertEqual(accuracy.shape, ())
input_fn = lambda: tf.random.uniform([3, 784])
predictions_generator = classifier.predict(input_fn)
for _ in range(3):
predictions = next(predictions_generator)
self.assertEqual(predictions['probabilities'].shape, (10,))
self.assertEqual(predictions['classes'].shape, ())
def mnist_model_fn_helper(self, mode, multi_gpu=False):
features, labels = dummy_input_fn()
image_count = features.shape[0]
spec = mnist.model_fn(features, labels, mode, {
'data_format': 'channels_last',
'multi_gpu': multi_gpu
})
if mode == tf.estimator.ModeKeys.PREDICT:
predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape, (image_count, 10))
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
self.assertAllEqual(predictions['classes'].shape, (image_count,))
self.assertEqual(predictions['classes'].dtype, tf.int64)
if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32)
if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_mnist_model_fn_train_mode(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.TRAIN)
def test_mnist_model_fn_train_mode_multi_gpu(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.TRAIN, multi_gpu=True)
def test_mnist_model_fn_eval_mode(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.EVAL)
def test_mnist_model_fn_predict_mode(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.PREDICT)
class Benchmarks(tf.test.Benchmark):
"""Simple speed benchmarking for MNIST."""
def benchmark_train_step_time(self):
classifier = make_estimator()
# Run one step to warmup any use of the GPU.
classifier.train(input_fn=dummy_input_fn, steps=1)
have_gpu = tf.test.is_gpu_available()
num_steps = 1000 if have_gpu else 100
name = 'train_step_time_%s' % ('gpu' if have_gpu else 'cpu')
start = time.time()
classifier.train(input_fn=dummy_input_fn, steps=num_steps)
end = time.time()
wall_time = (end - start) / num_steps
self.report_benchmark(
iters=num_steps,
wall_time=wall_time,
name=name,
extras={
'examples_per_sec': BATCH_SIZE / wall_time
})
if __name__ == '__main__':
logging.set_verbosity(logging.ERROR)
tf.disable_v2_behavior()
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MNIST model training using TPUs.
This program demonstrates training of the convolutional neural network model
defined in mnist.py on Google Cloud TPUs (https://cloud.google.com/tpu/).
If you are not interested in TPUs, you should ignore this file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
# pylint: disable=g-bad-import-order
from absl import app as absl_app # pylint: disable=unused-import
import tensorflow.compat.v1 as tf
# pylint: enable=g-bad-import-order
# For open source environment, add grandparent directory for import
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(sys.path[0]))))
from official.r1.mnist import dataset # pylint: disable=wrong-import-position
from official.r1.mnist import mnist # pylint: disable=wrong-import-position
# Cloud TPU Cluster Resolver flags
tf.flags.DEFINE_string(
"tpu", default=None,
help="The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.")
tf.flags.DEFINE_string(
"tpu_zone", default=None,
help="[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
tf.flags.DEFINE_string(
"gcp_project", default=None,
help="[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
# Model specific parameters
tf.flags.DEFINE_string("data_dir", "",
"Path to directory containing the MNIST dataset")
tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir")
tf.flags.DEFINE_integer("batch_size", 1024,
"Mini-batch size for the training. Note that this "
"is the global batch size and not the per-shard batch.")
tf.flags.DEFINE_integer("train_steps", 1000, "Total number of training steps.")
tf.flags.DEFINE_integer("eval_steps", 0,
"Total number of evaluation steps. If `0`, evaluation "
"after training is skipped.")
tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.")
tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs")
tf.flags.DEFINE_bool("enable_predict", True, "Do some predictions at the end")
tf.flags.DEFINE_integer("iterations", 50,
"Number of iterations per TPU training loop.")
tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).")
FLAGS = tf.flags.FLAGS
def metric_fn(labels, logits):
accuracy = tf.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1))
return {"accuracy": accuracy}
def model_fn(features, labels, mode, params):
"""model_fn constructs the ML model used to predict handwritten digits."""
del params
image = features
if isinstance(image, dict):
image = features["image"]
model = mnist.create_model("channels_last")
if mode == tf.estimator.ModeKeys.PREDICT:
logits = model(image, training=False)
predictions = {
'class_ids': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits),
}
return tf.estimator.tpu.TPUEstimatorSpec(mode, predictions=predictions)
logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
if mode == tf.estimator.ModeKeys.TRAIN:
learning_rate = tf.train.exponential_decay(
FLAGS.learning_rate,
tf.train.get_global_step(),
decay_steps=100000,
decay_rate=0.96)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
if FLAGS.use_tpu:
optimizer = tf.tpu.CrossShardOptimizer(optimizer)
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode,
loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_global_step()))
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
def train_input_fn(params):
"""train_input_fn defines the input pipeline used for training."""
batch_size = params["batch_size"]
data_dir = params["data_dir"]
# Retrieves the batch size for the current shard. The # of shards is
# computed according to the input pipeline deployment. See
# `tf.estimator.tpu.RunConfig` for details.
ds = dataset.train(data_dir).cache().repeat().shuffle(
buffer_size=50000).batch(batch_size, drop_remainder=True)
return ds
def eval_input_fn(params):
batch_size = params["batch_size"]
data_dir = params["data_dir"]
ds = dataset.test(data_dir).batch(batch_size, drop_remainder=True)
return ds
def predict_input_fn(params):
batch_size = params["batch_size"]
data_dir = params["data_dir"]
# Take out top 10 samples from test data to make the predictions.
ds = dataset.test(data_dir).take(10).batch(batch_size)
return ds
def main(argv):
del argv # Unused.
tf.logging.set_verbosity(tf.logging.INFO)
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
FLAGS.tpu,
zone=FLAGS.tpu_zone,
project=FLAGS.gcp_project
)
run_config = tf.estimator.tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=FLAGS.model_dir,
session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True),
tpu_config=tf.estimator.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
)
estimator = tf.estimator.tpu.TPUEstimator(
model_fn=model_fn,
use_tpu=FLAGS.use_tpu,
train_batch_size=FLAGS.batch_size,
eval_batch_size=FLAGS.batch_size,
predict_batch_size=FLAGS.batch_size,
params={"data_dir": FLAGS.data_dir},
config=run_config)
# TPUEstimator.train *requires* a max_steps argument.
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
# TPUEstimator.evaluate *requires* a steps argument.
# Note that the number of examples used during evaluation is
# --eval_steps * --batch_size.
# So if you change --batch_size then change --eval_steps too.
if FLAGS.eval_steps:
estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps)
# Run prediction on top few samples of test data.
if FLAGS.enable_predict:
predictions = estimator.predict(input_fn=predict_input_fn)
for pred_dict in predictions:
template = ('Prediction is "{}" ({:.1f}%).')
class_id = pred_dict['class_ids']
probability = pred_dict['probabilities'][class_id]
print(template.format(class_id, 100 * probability))
if __name__ == "__main__":
tf.disable_v2_behavior()
absl_app.run(main)
![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
# NCF Estimator implementation
NCF framework to train and evaluate the NeuMF model
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""NCF framework to train and evaluate the NeuMF model.
The NeuMF model assembles both MF and MLP models under the NCF framework. Check
`neumf_model.py` for more details about the models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import heapq
import json
import math
import multiprocessing
import os
import signal
from absl import app as absl_app
from absl import flags
from absl import logging
import numpy as np
from six.moves import range
import tensorflow as tf
import typing
from official.r1.utils.logs import hooks_helper
from official.r1.utils.logs import logger
from official.r1.utils.logs import mlperf_helper
from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import data_preprocessing
from official.recommendation import movielens
from official.recommendation import ncf_common
from official.recommendation import neumf_model
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
FLAGS = flags.FLAGS
def construct_estimator(model_dir, params):
"""Construct either an Estimator for NCF.
Args:
model_dir: The model directory for the estimator
params: The params dict for the estimator
Returns:
An Estimator.
"""
distribution = ncf_common.get_v1_distribution_strategy(params)
run_config = tf.estimator.RunConfig(train_distribute=distribution,
eval_distribute=distribution)
model_fn = neumf_model.neumf_model_fn
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
config=run_config, params=params)
return estimator
def log_and_get_hooks(eval_batch_size):
"""Convenience function for hook and logger creation."""
# Create hooks that log information about the training and metric values
train_hooks = hooks_helper.get_train_hooks(
FLAGS.hooks,
model_dir=FLAGS.model_dir,
batch_size=FLAGS.batch_size, # for ExamplesPerSecondHook
tensors_to_log={"cross_entropy": "cross_entropy"}
)
run_params = {
"batch_size": FLAGS.batch_size,
"eval_batch_size": eval_batch_size,
"number_factors": FLAGS.num_factors,
"hr_threshold": FLAGS.hr_threshold,
"train_epochs": FLAGS.train_epochs,
}
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info(
model_name="recommendation",
dataset_name=FLAGS.dataset,
run_params=run_params,
test_id=FLAGS.benchmark_test_id)
return benchmark_logger, train_hooks
def main(_):
with logger.benchmark_context(FLAGS), \
mlperf_helper.LOGGER(FLAGS.output_ml_perf_compliance_logging):
mlperf_helper.set_ncf_root(os.path.split(os.path.abspath(__file__))[0])
run_ncf(FLAGS)
def run_ncf(_):
"""Run NCF training and eval loop."""
params = ncf_common.parse_flags(FLAGS)
num_users, num_items, num_train_steps, num_eval_steps, producer = (
ncf_common.get_inputs(params))
params["num_users"], params["num_items"] = num_users, num_items
producer.start()
model_helpers.apply_clean(flags.FLAGS)
estimator = construct_estimator(model_dir=FLAGS.model_dir, params=params)
benchmark_logger, train_hooks = log_and_get_hooks(params["eval_batch_size"])
total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals
target_reached = False
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.TRAIN_LOOP)
for cycle_index in range(total_training_cycle):
assert FLAGS.epochs_between_evals == 1 or not mlperf_helper.LOGGER.enabled
logging.info("Starting a training cycle: {}/{}".format(
cycle_index + 1, total_training_cycle))
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.TRAIN_EPOCH,
value=cycle_index)
train_input_fn = producer.make_input_fn(is_training=True)
estimator.train(input_fn=train_input_fn, hooks=train_hooks,
steps=num_train_steps)
logging.info("Beginning evaluation.")
eval_input_fn = producer.make_input_fn(is_training=False)
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_START,
value=cycle_index)
eval_results = estimator.evaluate(eval_input_fn, steps=num_eval_steps)
logging.info("Evaluation complete.")
hr = float(eval_results[rconst.HR_KEY])
ndcg = float(eval_results[rconst.NDCG_KEY])
loss = float(eval_results["loss"])
mlperf_helper.ncf_print(
key=mlperf_helper.TAGS.EVAL_TARGET,
value={"epoch": cycle_index, "value": FLAGS.hr_threshold})
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_ACCURACY,
value={"epoch": cycle_index, "value": hr})
mlperf_helper.ncf_print(
key=mlperf_helper.TAGS.EVAL_HP_NUM_NEG,
value={"epoch": cycle_index, "value": rconst.NUM_EVAL_NEGATIVES})
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_STOP, value=cycle_index)
# Benchmark the evaluation results
benchmark_logger.log_evaluation_result(eval_results)
# Log the HR and NDCG results.
logging.info(
"Iteration {}: HR = {:.4f}, NDCG = {:.4f}, Loss = {:.4f}".format(
cycle_index + 1, hr, ndcg, loss))
# If some evaluation threshold is met
if model_helpers.past_stop_threshold(FLAGS.hr_threshold, hr):
target_reached = True
break
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.RUN_STOP,
value={"success": target_reached})
producer.stop_loop()
producer.join()
# Clear the session explicitly to avoid session delete error
tf.keras.backend.clear_session()
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.RUN_FINAL)
if __name__ == "__main__":
logging.set_verbosity(logging.INFO)
ncf_common.define_ncf_flags()
absl_app.run(main)
![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
# ResNet in TensorFlow
Deep residual networks, or ResNets for short, provided the breakthrough idea of
identity mappings in order to enable training of very deep convolutional neural
networks. This folder contains an implementation of ResNet for the ImageNet
dataset written in TensorFlow.
See the following papers for more background:
[1] [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
[2] [Identity Mappings in Deep Residual Networks](https://arxiv.org/pdf/1603.05027.pdf) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
In code, v1 refers to the ResNet defined in [1] but where a stride 2 is used on
the 3x3 conv rather than the first 1x1 in the bottleneck. This change results
in higher and more stable accuracy with less epochs than the original v1 and has
shown to scale to higher batch sizes with minimal degradation in accuracy.
There is no originating paper. The first mention we are aware of was in the
torch version of [ResNetv1](https://github.com/facebook/fb.resnet.torch). Most
popular v1 implementations are this implementation which we call ResNetv1.5.
In testing we found v1.5 requires ~12% more compute to train and has 6% reduced
throughput for inference compared to ResNetv1. CIFAR-10 ResNet does not use the
bottleneck and is thus the same for v1 as v1.5.
v2 refers to [2]. The principle difference between the two versions is that v1
applies batch normalization and activation after convolution, while v2 applies
batch normalization, then activation, and finally convolution. A schematic
comparison is presented in Figure 1 (left) of [2].
Please proceed according to which dataset you would like to train/evaluate on:
## CIFAR-10
### Setup
You need to have the latest version of TensorFlow installed.
First, make sure [the models folder is in your Python path](/official/#running-the-models); otherwise you may encounter `ImportError: No module named official.resnet`.
Then, download and extract the CIFAR-10 data from Alex's website, specifying the location with the `--data_dir` flag. Run the following:
```bash
python cifar10_download_and_extract.py --data_dir <DATA_DIR>
```
Then, to train the model:
```bash
python cifar10_main.py --data_dir <DATA_DIR>/cifar-10-batches-bin --model_dir <MODEL_DIR>
```
Use `--data_dir` to specify the location of the CIFAR-10 data used in the previous step. There are more flag options as described in `cifar10_main.py`.
To export a `SavedModel` from the trained checkpoint:
```bash
python cifar10_main.py --data_dir <DATA_DIR>/cifar-10-batches-bin --model_dir <MODEL_DIR> --eval_only --export_dir <EXPORT_DIR>
```
Note: The `<EXPORT_DIR>` must be present. You might want to run `mkdir <EXPORT_DIR>` beforehand.
The `SavedModel` can then be [loaded](https://www.tensorflow.org/guide/saved_model#loading_a_savedmodel_in_python) in order to use the ResNet for prediction.
## ImageNet
### Setup
To begin, you will need to download the ImageNet dataset and convert it to
TFRecord format. The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py)
and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy)
provide a few options.
Once your dataset is ready, you can begin training the model as follows:
```bash
python imagenet_main.py --data_dir=/path/to/imagenet
```
The model will begin training and will automatically evaluate itself on the
validation data roughly once per epoch.
Note that there are a number of other options you can specify, including
`--model_dir` to choose where to store the model and `--resnet_size` to choose
the model size (options include ResNet-18 through ResNet-200). See
[`resnet_run_loop.py`](resnet_run_loop.py) for the full list of options.
## Compute Devices
Training is accomplished using the DistributionStrategies API. (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md)
The appropriate distribution strategy is chosen based on the `--num_gpus` flag.
By default this flag is one if TensorFlow is compiled with CUDA, and zero
otherwise.
num_gpus:
+ 0: Use OneDeviceStrategy and train on CPU.
+ 1: Use OneDeviceStrategy and train on GPU.
+ 2+: Use MirroredStrategy (data parallelism) to distribute a batch between devices.
### Pre-trained model
You can download pre-trained versions of ResNet-50. Reported accuracies are top-1 single-crop accuracy for the ImageNet validation set.
Models are reported as both checkpoints produced by Estimator during training, and as SavedModels which are more portable. Checkpoints are fragile,
and these are not guaranteed to work with future versions of the code. Both ResNet v1
and ResNet v2 have been trained in both fp16 and fp32 precision. (Here v1 refers to "v1.5". See the note above.) Furthermore, SavedModels
are generated to accept either tensor or JPG inputs, and with channels_first (NCHW) and channels_last (NHWC) convolutions. NCHW is generally
better for GPUs, while NHWC is generally better for CPUs. See the TensorFlow [performance guide](https://www.tensorflow.org/performance/performance_guide#data_formats)
for more details.
ResNet-50 v2 (fp32, Accuracy 76.47%):
* [Checkpoint](http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v2_fp32_20181001.tar.gz)
* SavedModel [(NCHW)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NCHW.tar.gz),
[(NCHW, JPG)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NCHW_jpg.tar.gz),
[(NHWC)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NHWC.tar.gz),
[(NHWC, JPG)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NHWC_jpg.tar.gz)
ResNet-50 v2 (fp16, Accuracy 76.56%):
* [Checkpoint](http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v2_fp16_20180928.tar.gz)
* SavedModel [(NCHW)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp16_savedmodel_NCHW.tar.gz),
[(NCHW, JPG)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp16_savedmodel_NCHW_jpg.tar.gz),
[(NHWC)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp16_savedmodel_NHWC.tar.gz),
[(NHWC, JPG)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp16_savedmodel_NHWC_jpg.tar.gz)
ResNet-50 v1 (fp32, Accuracy 76.53%):
* [Checkpoint](http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v1_fp32_20181001.tar.gz)
* SavedModel [(NCHW)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v1_fp32_savedmodel_NCHW.tar.gz),
[(NCHW, JPG)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v1_fp32_savedmodel_NCHW_jpg.tar.gz),
[(NHWC)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v1_fp32_savedmodel_NHWC.tar.gz),
[(NHWC, JPG)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v1_fp32_savedmodel_NHWC_jpg.tar.gz)
ResNet-50 v1 (fp16, Accuracy 76.18%):
* [Checkpoint](http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v1_fp16_20181001.tar.gz)
* SavedModel [(NCHW)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v1_fp16_savedmodel_NCHW.tar.gz),
[(NCHW, JPG)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v1_fp16_savedmodel_NCHW_jpg.tar.gz),
[(NHWC)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v1_fp16_savedmodel_NHWC.tar.gz),
[(NHWC, JPG)](http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v1_fp16_savedmodel_NHWC_jpg.tar.gz)
### Transfer Learning
You can use a pretrained model to initialize a training process. In addition you are able to freeze all but the final fully connected layers to fine tune your model. Transfer Learning is useful when training on your own small datasets. For a brief look at transfer learning in the context of convolutional neural networks, we recommend reading these [short notes](http://cs231n.github.io/transfer-learning/).
To fine tune a pretrained resnet you must make three changes to your training procedure:
1) Build the exact same model as previously except we change the number of labels in the final classification layer.
2) Restore all weights from the pre-trained resnet except for the final classification layer; this will get randomly initialized instead.
3) Freeze earlier layers of the network
We can perform these three operations by specifying two flags: ```--pretrained_model_checkpoint_path``` and ```--fine_tune```. The first flag is a string that points to the path of a pre-trained resnet model. If this flag is specified, it will load all but the final classification layer. A key thing to note: if both ```--pretrained_model_checkpoint_path``` and a non empty ```model_dir``` directory are passed, the tensorflow estimator will load only the ```model_dir```. For more on this please see [WarmStartSettings](https://www.tensorflow.org/versions/master/api_docs/python/tf/estimator/WarmStartSettings) and [Estimators](https://www.tensorflow.org/guide/estimators).
The second flag ```--fine_tune``` is a boolean that indicates whether earlier layers of the network should be frozen. You may set this flag to false if you wish to continue training a pre-trained model from a checkpoint. If you set this flag to true, you can train a new classification layer from scratch.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment