Unverified Commit 587f5792 authored by Taylor Robie, committed by GitHub

Add reference data tests to official. (#3723)

* Add golden test util to streamline symbolic and numerical comparison to reference graphs, and apply golden tests to ResNet.

update tests

use more concise logic for path property

delint

add some comments

delint

address PR comments

make resnet tests more concise, and suppress warning test in py2

change resnet name template

more shuffling of data dirs

address PR comments and add tensorflow version info

Remove subTest due to py2

switch from tf.__version__ to tf.VERSION, and include tf.GIT_VERSION

suppress lint error from json load unpack

* address PR comments

* address PR comments

* delint
parent 1730eed4
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test that the definitions of ResNet layers haven't changed.
These tests will fail if either:
a) The graph of a resnet layer changes and the change is significant enough
that it can no longer load existing checkpoints.
b) The numerical results produced by the layer change.
A warning will be issued if the graph changes, but the checkpoint still loads.
In the event that a layer change is intended, or the TensorFlow implementation
of a layer changes (and thus changes the graph), regenerate using the command:
$ python3 layer_test.py -regen
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import resnet_model
from official.utils.testing import reference_data
DATA_FORMAT = "channels_last" # CPU instructions often preclude channels_first
BATCH_SIZE = 32
BLOCK_TESTS = [
dict(bottleneck=True, projection=True, version=1, width=8, channels=4),
dict(bottleneck=True, projection=True, version=2, width=8, channels=4),
dict(bottleneck=True, projection=False, version=1, width=8, channels=4),
dict(bottleneck=True, projection=False, version=2, width=8, channels=4),
dict(bottleneck=False, projection=True, version=1, width=8, channels=4),
dict(bottleneck=False, projection=True, version=2, width=8, channels=4),
dict(bottleneck=False, projection=False, version=1, width=8, channels=4),
dict(bottleneck=False, projection=False, version=2, width=8, channels=4),
]
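# Note (added for clarity): the eight entries above enumerate every
# combination of bottleneck x projection x ResNet version, so each block
# variant is pinned against its own golden reference directory.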
class BaseTest(reference_data.BaseTest):
"""Tests for core ResNet layers."""
@property
def test_name(self):
return "resnet"
def _batch_norm_ops(self, test=False):
name = "batch_norm"
g = tf.Graph()
with g.as_default():
tf.set_random_seed(self.name_to_seed(name))
input_tensor = tf.get_variable(
"input_tensor", dtype=tf.float32,
initializer=tf.random_uniform((32, 16, 16, 3), maxval=1)
)
layer = resnet_model.batch_norm(
inputs=input_tensor, data_format=DATA_FORMAT, training=True)
self._save_or_test_ops(
name=name, graph=g, ops_to_eval=[input_tensor, layer], test=test,
correctness_function=self.default_correctness_function
)
def make_projection(self, filters_out, strides, data_format):
"""1D convolution with stride projector.
Args:
filters_out: Number of filters in the projection.
strides: Stride length for convolution.
data_format: channels_first or channels_last
Returns:
A CNN projector function with kernel_size 1.
"""
def projection_shortcut(inputs):
return resnet_model.conv2d_fixed_padding(
inputs=inputs, filters=filters_out, kernel_size=1, strides=strides,
data_format=data_format)
return projection_shortcut
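  # Illustrative use (not in the original file):
  #   shortcut = self.make_projection(8, 2, "channels_last")
  # returns a callable that applies a kernel_size=1, stride-2 convolution,
  # projecting its input to 8 output filters.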
def _resnet_block_ops(self, test, batch_size, bottleneck, projection,
version, width, channels):
"""Test whether resnet block construction has changed.
Args:
test: Whether or not to run as a test case.
      batch_size: Number of images in the fake input batch. This is needed due
        to batch normalization.
bottleneck: Whether or not to use bottleneck layers.
projection: Whether or not to project the input.
version: Which version of ResNet to test.
width: The width of the fake image.
channels: The number of channels in the fake image.
"""
name = "batch-size-{}_{}{}_version-{}_width-{}_channels-{}".format(
batch_size,
"bottleneck" if bottleneck else "building",
"_projection" if projection else "",
version,
width,
channels
)
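    # Worked example (added note): BLOCK_TESTS[0] with BATCH_SIZE = 32 yields
    # the directory name
    # "batch-size-32_bottleneck_projection_version-1_width-8_channels-4".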
if version == 1:
block_fn = resnet_model._building_block_v1
if bottleneck:
block_fn = resnet_model._bottleneck_block_v1
else:
block_fn = resnet_model._building_block_v2
if bottleneck:
block_fn = resnet_model._bottleneck_block_v2
g = tf.Graph()
with g.as_default():
tf.set_random_seed(self.name_to_seed(name))
strides = 1
channels_out = channels
projection_shortcut = None
if projection:
strides = 2
channels_out *= strides
projection_shortcut = self.make_projection(
filters_out=channels_out, strides=strides, data_format=DATA_FORMAT)
filters = channels_out
if bottleneck:
filters = channels_out // 4
input_tensor = tf.get_variable(
"input_tensor", dtype=tf.float32,
initializer=tf.random_uniform((batch_size, width, width, channels),
maxval=1)
)
layer = block_fn(inputs=input_tensor, filters=filters, training=True,
projection_shortcut=projection_shortcut, strides=strides,
data_format=DATA_FORMAT)
self._save_or_test_ops(
name=name, graph=g, ops_to_eval=[input_tensor, layer], test=test,
correctness_function=self.default_correctness_function
)
def test_batch_norm(self):
self._batch_norm_ops(test=True)
def test_block_0(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[0])
def test_block_1(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[1])
def test_block_2(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[2])
def test_block_3(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[3])
def test_block_4(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[4])
def test_block_5(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[5])
def test_block_6(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[6])
def test_block_7(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[7])
def regenerate(self):
"""Create reference data files for ResNet layer tests."""
self._batch_norm_ops(test=False)
for block_params in BLOCK_TESTS:
self._resnet_block_ops(test=False, batch_size=BATCH_SIZE, **block_params)
if __name__ == "__main__":
reference_data.main(argv=sys.argv, test_class=BaseTest)
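# Example invocations (an illustrative sketch; the flags are defined by
# ReferenceDataActionParser in official/utils/testing/reference_data.py):
#   $ python3 layer_test.py               # run the reference tests
#   $ python3 layer_test.py --regenerate  # rebuild golden files (py3 only)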
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow testing subclass to automate numerical testing.
Reference tests determine when behavior deviates from some "gold standard," and
are useful for determining when layer definitions have changed without
performing full regression testing, which is generally prohibitive. This class
handles the symbolic graph comparison as well as loading weights to avoid
relying on random number generation, which can change.
The tests performed by this class are:
1) Compare a generated graph against a reference graph. Differences are not
necessarily fatal.
2) Attempt to load known weights for the graph. If this step succeeds but
changes are present in the graph, a warning is issued but does not raise
an exception.
3) Perform a calculation and compare the result to a reference value.
This class also provides a method to generate reference data.
Note:
The test class is responsible for fixing the random seed during graph
definition. A convenience method name_to_seed() is provided to make this
process easier.
The test class should also define a .regenerate() class method which (usually)
just calls the op definition function with test=False for all relevant tests.
A concise example of this class in action is provided in:
official/utils/testing/reference_data_test.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import hashlib
import json
import os
import shutil
import sys
import numpy as np
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
class BaseTest(tf.test.TestCase):
"""TestCase subclass for performing reference data tests."""
def regenerate(self):
"""Subclasses should override this function to generate a new reference."""
raise NotImplementedError
@property
def test_name(self):
"""Subclass should define its own name."""
raise NotImplementedError
@property
def data_root(self):
"""Use the subclass directory rather than the parent directory.
Returns:
The path prefix for reference data.
"""
return os.path.join(os.path.split(
os.path.abspath(__file__))[0], "reference_data", self.test_name)
ckpt_prefix = "model.ckpt"
@staticmethod
def name_to_seed(name):
"""Convert a string into a 32 bit integer.
    This function allows test cases to easily generate deterministic seeds by
    hashing the name of the test. The hash string is in hex rather than base
    10, which is why the int call uses base 16, and the modulo projects the
    seed from a 128 bit int down to 32 bits for readability.
Args:
name: A string containing the name of a test.
Returns:
A pseudo-random 32 bit integer derived from name.
"""
seed = hashlib.md5(name.encode("utf-8")).hexdigest()
return int(seed, 16) % (2**32 - 1)
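  # Usage sketch (illustrative): name_to_seed("batch_norm") always maps to
  # the same 32 bit value, so calling tf.set_random_seed() with it makes
  # graph construction reproducible across runs and machines.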
@staticmethod
def common_tensor_properties(input_array):
"""Convenience function for matrix testing.
    In tests we wish to determine whether a result has changed. However, storing
an entire n-dimensional array is impractical. A better approach is to
calculate several values from that array and test that those derived values
are unchanged. The properties themselves are arbitrary and should be chosen
to be good proxies for a full equality test.
Args:
input_array: A numpy array from which key values are extracted.
Returns:
A list of values derived from the input_array for equality tests.
"""
output = list(input_array.shape)
flat_array = input_array.flatten()
output.extend([float(i) for i in
[flat_array[0], flat_array[-1], np.sum(flat_array)]])
return output
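  # Worked example (added for illustration):
  #   common_tensor_properties(np.array([[1., 2.], [3., 4.]]))
  #   returns [2, 2, 1.0, 4.0, 10.0]  # shape dims, first, last, sum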
def default_correctness_function(self, *args):
"""Returns a vector with the concatenation of common properties.
This function simply calls common_tensor_properties() for every element.
It is useful as it allows one to easily construct tests of layers without
having to worry about the details of result checking.
Args:
*args: A list of numpy arrays corresponding to tensors which have been
evaluated.
Returns:
A list of values containing properties for every element in args.
"""
output = []
for arg in args:
output.extend(self.common_tensor_properties(arg))
return output
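  # For instance (illustrative): with ops_to_eval=[input_tensor, layer], the
  # value written to results.json is just the two property lists concatenated,
  # which matches the flat layout of the reference results in this commit.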
def _construct_and_save_reference_files(
self, name, graph, ops_to_eval, correctness_function):
"""Save reference data files.
Constructs a serialized graph_def, layer weights, and computation results.
It then saves them to files which are read at test time.
Args:
name: String defining the run. This will be used to define folder names
and will be used for random seed construction.
graph: The graph in which the test is conducted.
ops_to_eval: Ops which the user wishes to be evaluated under a controlled
session.
correctness_function: This function accepts the evaluated results of
ops_to_eval, and returns a list of values. This list must be JSON
serializable; in particular it is up to the user to convert numpy
dtypes into builtin dtypes.
"""
data_dir = os.path.join(self.data_root, name)
# Make sure there is a clean space for results.
if os.path.exists(data_dir):
shutil.rmtree(data_dir)
os.makedirs(data_dir)
# Serialize graph for comparison.
graph_bytes = graph.as_graph_def().SerializeToString()
expected_file = os.path.join(data_dir, "expected_graph")
with open(expected_file, "wb") as f:
f.write(graph_bytes)
with graph.as_default():
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with self.test_session(graph=graph) as sess:
sess.run(init)
saver.save(sess=sess, save_path=os.path.join(data_dir, self.ckpt_prefix))
# These files are not needed for this test.
os.remove(os.path.join(data_dir, "checkpoint"))
os.remove(os.path.join(data_dir, self.ckpt_prefix + ".meta"))
      # Ops are evaluated even when there is no correctness function, to
      # ensure that they can be run without error.
eval_results = [op.eval() for op in ops_to_eval]
if correctness_function is not None:
results = correctness_function(*eval_results)
with open(os.path.join(data_dir, "results.json"), "wt") as f:
json.dump(results, f)
with open(os.path.join(data_dir, "tf_version.json"), "wt") as f:
json.dump([tf.VERSION, tf.GIT_VERSION], f)
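  # Resulting layout per test case (a sketch; exact checkpoint shard names
  # depend on the TensorFlow version):
  #   reference_data/<test_name>/<name>/expected_graph
  #   reference_data/<test_name>/<name>/model.ckpt.index, model.ckpt.data-*
  #   reference_data/<test_name>/<name>/results.json  (only if a
  #     correctness_function was supplied)
  #   reference_data/<test_name>/<name>/tf_version.json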
def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
"""Determine if a graph agrees with the reference data.
Args:
name: String defining the run. This will be used to define folder names
and will be used for random seed construction.
graph: The graph in which the test is conducted.
ops_to_eval: Ops which the user wishes to be evaluated under a controlled
session.
correctness_function: This function accepts the evaluated results of
ops_to_eval, and returns a list of values. This list must be JSON
serializable; in particular it is up to the user to convert numpy
dtypes into builtin dtypes.
"""
data_dir = os.path.join(self.data_root, name)
# Serialize graph for comparison.
graph_bytes = graph.as_graph_def().SerializeToString()
expected_file = os.path.join(data_dir, "expected_graph")
with open(expected_file, "rb") as f:
expected_graph_bytes = f.read()
    # The serialization is not deterministic byte-for-byte. Instead a utility
    # which compares the semantics of the two graphs is used to test for
    # equality. This has the added benefit of providing some information on
    # what changed.
    #   Note: The summary only shows the first difference detected. It is not
    #   an exhaustive summary of differences.
differences = pywrap_tensorflow.EqualGraphDefWrapper(
graph_bytes, expected_graph_bytes).decode("utf-8")
with graph.as_default():
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with open(os.path.join(data_dir, "tf_version.json"), "rt") as f:
tf_version_reference, tf_git_version_reference = json.load(f) # pylint: disable=unpacking-non-sequence
tf_version_comparison = ""
if tf.GIT_VERSION != tf_git_version_reference:
tf_version_comparison = (
"Test was built using: {} (git = {})\n"
"Local TensorFlow version: {} (git = {})"
.format(tf_version_reference, tf_git_version_reference,
tf.VERSION, tf.GIT_VERSION)
)
with self.test_session(graph=graph) as sess:
sess.run(init)
try:
saver.restore(sess=sess, save_path=os.path.join(
data_dir, self.ckpt_prefix))
if differences:
tf.logging.warn(
"The provided graph is different than expected:\n {}\n"
"However the weights were still able to be loaded.\n{}".format(
differences, tf_version_comparison)
)
except: # pylint: disable=bare-except
raise self.failureException(
"Weight load failed. Graph comparison:\n {}{}"
.format(differences, tf_version_comparison))
eval_results = [op.eval() for op in ops_to_eval]
if correctness_function is not None:
results = correctness_function(*eval_results)
with open(os.path.join(data_dir, "results.json"), "rt") as f:
expected_results = json.load(f)
self.assertAllClose(results, expected_results)
def _save_or_test_ops(self, name, graph, ops_to_eval=None, test=True,
correctness_function=None):
"""Utility function to automate repeated work of graph checking and saving.
The philosophy of this function is that the user need only define ops on
a graph and specify which results should be validated. The actual work of
managing snapshots and calculating results should be automated away.
Args:
name: String defining the run. This will be used to define folder names
and will be used for random seed construction.
graph: The graph in which the test is conducted.
ops_to_eval: Ops which the user wishes to be evaluated under a controlled
session.
test: Boolean. If True this function will test graph correctness, load
weights, and compute numerical values. If False the necessary test data
will be generated and saved.
correctness_function: This function accepts the evaluated results of
ops_to_eval, and returns a list of values. This list must be JSON
serializable; in particular it is up to the user to convert numpy
dtypes into builtin dtypes.
"""
ops_to_eval = ops_to_eval or []
if test:
try:
self._evaluate_test_case(
name=name, graph=graph, ops_to_eval=ops_to_eval,
correctness_function=correctness_function
)
      except:  # pylint: disable=bare-except
tf.logging.error("Failed unittest {}".format(name))
raise
else:
self._construct_and_save_reference_files(
name=name, graph=graph, ops_to_eval=ops_to_eval,
correctness_function=correctness_function
)
class ReferenceDataActionParser(argparse.ArgumentParser):
"""Minimal arg parser so that test regeneration can be called from the CLI."""
def __init__(self):
super(ReferenceDataActionParser, self).__init__()
self.add_argument(
"--regenerate", "-regen",
action="store_true",
help="Enable this flag to regenerate test data. If not set unit tests"
"will be run."
)
def main(argv, test_class):
"""Simple switch function to allow test regeneration from the CLI."""
flags = ReferenceDataActionParser().parse_args(argv[1:])
if flags.regenerate:
if sys.version_info[0] == 2:
raise NameError("\nPython2 unittest does not support being run as a "
"standalone class.\nAs a result tests must be "
"regenerated using Python3.\n"
"Tests can be run under 2 or 3.")
test_class().regenerate()
else:
tf.test.main()
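# A minimal subclass sketch (hypothetical example, not part of this module;
# see official/utils/testing/reference_data_test.py for a real one):
#
#   class GoldenTest(BaseTest):
#
#     @property
#     def test_name(self):
#       return "matmul"
#
#     def _matmul_ops(self, test=False):
#       g = tf.Graph()
#       with g.as_default():
#         tf.set_random_seed(self.name_to_seed("matmul"))
#         x = tf.get_variable("x", initializer=tf.random_uniform((4, 4)))
#         y = tf.matmul(x, x)
#       self._save_or_test_ops(
#           name="matmul", graph=g, ops_to_eval=[y], test=test,
#           correctness_function=self.default_correctness_function)
#
#     def test_matmul(self):
#       self._matmul_ops(test=True)
#
#     def regenerate(self):
#       self._matmul_ops(test=False)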
Reference data files added by this commit (these appear to be the generated
results.json and tf_version.json contents):

[1, 1, 0.4701630473136902, 0.4701630473136902, 0.4701630473136902]
["1.8.0-dev20180325", "v1.7.0-rc1-750-g6c1737e6c8"]
["1.8.0-dev20180325", "v1.7.0-rc1-750-g6c1737e6c8"]
[32, 8, 8, 4, 0.08920872211456299, 0.8918969631195068, 4064.7060546875, 32, 4, 4, 8, 0.0, 0.8524793982505798, 2294.368896484375]