Unverified Commit 0cceabfc authored by Yiming Shi's avatar Yiming Shi Committed by GitHub
Browse files

Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor

parents 17821c0d 39ee0ac9
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hook that counts examples per second every N steps or seconds."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.r1.utils.logs import logger
class ExamplesPerSecondHook(tf.estimator.SessionRunHook):
"""Hook to print out examples per second.
Total time is tracked and then divided by the total number of steps
to get the average step time and then batch_size is used to determine
the running average of examples per second. The examples per second for the
most recent interval is also logged.
"""
def __init__(self,
batch_size,
every_n_steps=None,
every_n_secs=None,
warm_steps=0,
metric_logger=None):
"""Initializer for ExamplesPerSecondHook.
Args:
batch_size: Total batch size across all workers used to calculate
examples/second from global time.
every_n_steps: Log stats every n steps.
every_n_secs: Log stats every n seconds. Exactly one of the
`every_n_steps` or `every_n_secs` should be set.
warm_steps: The number of steps to be skipped before logging and running
average calculation. warm_steps steps refers to global steps across all
workers, not on each worker
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
hook should use to write the log. If None, BaseBenchmarkLogger will
be used.
Raises:
ValueError: if neither `every_n_steps` or `every_n_secs` is set, or
both are set.
"""
if (every_n_steps is None) == (every_n_secs is None):
raise ValueError("exactly one of every_n_steps"
" and every_n_secs should be provided.")
self._logger = metric_logger or logger.BaseBenchmarkLogger()
self._timer = tf.estimator.SecondOrStepTimer(
every_steps=every_n_steps, every_secs=every_n_secs)
self._step_train_time = 0
self._total_steps = 0
self._batch_size = batch_size
self._warm_steps = warm_steps
# List of examples per second logged every_n_steps.
self.current_examples_per_sec_list = []
def begin(self):
"""Called once before using the session to check global step."""
self._global_step_tensor = tf.compat.v1.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use StepCounterHook.")
def before_run(self, run_context): # pylint: disable=unused-argument
"""Called before each call to run().
Args:
run_context: A SessionRunContext object.
Returns:
A SessionRunArgs object or None if never triggered.
"""
return tf.estimator.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values): # pylint: disable=unused-argument
"""Called after each call to run().
Args:
run_context: A SessionRunContext object.
run_values: A SessionRunValues object.
"""
global_step = run_values.results
if self._timer.should_trigger_for_step(
global_step) and global_step > self._warm_steps:
elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
global_step)
if elapsed_time is not None:
self._step_train_time += elapsed_time
self._total_steps += elapsed_steps
# average examples per second is based on the total (accumulative)
# training steps and training time so far
average_examples_per_sec = self._batch_size * (
self._total_steps / self._step_train_time)
# current examples per second is based on the elapsed training steps
# and training time per batch
current_examples_per_sec = self._batch_size * (
elapsed_steps / elapsed_time)
# Logs entries to be read from hook during or after run.
self.current_examples_per_sec_list.append(current_examples_per_sec)
self._logger.log_metric(
"average_examples_per_sec", average_examples_per_sec,
global_step=global_step)
self._logger.log_metric(
"current_examples_per_sec", current_examples_per_sec,
global_step=global_step)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks helper to return a list of TensorFlow hooks for training by name.
More hooks can be added to this set. To add a new hook, 1) add the new hook to
the registry in HOOKS, 2) add a corresponding function that parses out necessary
parameters.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from absl import logging
from official.r1.utils.logs import hooks
from official.r1.utils.logs import logger
from official.r1.utils.logs import metric_hook
_TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate',
'cross_entropy',
'train_accuracy'])
def get_train_hooks(name_list, use_tpu=False, **kwargs):
"""Factory for getting a list of TensorFlow hooks for training by name.
Args:
name_list: a list of strings to name desired hook classes. Allowed:
LoggingTensorHook, ProfilerHook, ExamplesPerSecondHook, which are defined
as keys in HOOKS
use_tpu: Boolean of whether computation occurs on a TPU. This will disable
hooks altogether.
**kwargs: a dictionary of arguments to the hooks.
Returns:
list of instantiated hooks, ready to be used in a classifier.train call.
Raises:
ValueError: if an unrecognized name is passed.
"""
if not name_list:
return []
if use_tpu:
logging.warning(
'hooks_helper received name_list `%s`, but a '
'TPU is specified. No hooks will be used.', name_list)
return []
train_hooks = []
for name in name_list:
hook_name = HOOKS.get(name.strip().lower())
if hook_name is None:
raise ValueError('Unrecognized training hook requested: {}'.format(name))
else:
train_hooks.append(hook_name(**kwargs))
return train_hooks
def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs): # pylint: disable=unused-argument
"""Function to get LoggingTensorHook.
Args:
every_n_iter: `int`, print the values of `tensors` once every N local
steps taken on the current worker.
tensors_to_log: List of tensor names or dictionary mapping labels to tensor
names. If not set, log _TENSORS_TO_LOG by default.
**kwargs: a dictionary of arguments to LoggingTensorHook.
Returns:
Returns a LoggingTensorHook with a standard set of tensors that will be
printed to stdout.
"""
if tensors_to_log is None:
tensors_to_log = _TENSORS_TO_LOG
return tf.estimator.LoggingTensorHook(
tensors=tensors_to_log,
every_n_iter=every_n_iter)
def get_profiler_hook(model_dir, save_steps=1000, **kwargs): # pylint: disable=unused-argument
"""Function to get ProfilerHook.
Args:
model_dir: The directory to save the profile traces to.
save_steps: `int`, print profile traces every N steps.
**kwargs: a dictionary of arguments to ProfilerHook.
Returns:
Returns a ProfilerHook that writes out timelines that can be loaded into
profiling tools like chrome://tracing.
"""
return tf.estimator.ProfilerHook(save_steps=save_steps, output_dir=model_dir)
def get_examples_per_second_hook(every_n_steps=100,
batch_size=128,
warm_steps=5,
**kwargs): # pylint: disable=unused-argument
"""Function to get ExamplesPerSecondHook.
Args:
every_n_steps: `int`, print current and average examples per second every
N steps.
batch_size: `int`, total batch size used to calculate examples/second from
global time.
warm_steps: skip this number of steps before logging and running average.
**kwargs: a dictionary of arguments to ExamplesPerSecondHook.
Returns:
Returns a ProfilerHook that writes out timelines that can be loaded into
profiling tools like chrome://tracing.
"""
return hooks.ExamplesPerSecondHook(
batch_size=batch_size, every_n_steps=every_n_steps,
warm_steps=warm_steps, metric_logger=logger.get_benchmark_logger())
def get_logging_metric_hook(tensors_to_log=None,
every_n_secs=600,
**kwargs): # pylint: disable=unused-argument
"""Function to get LoggingMetricHook.
Args:
tensors_to_log: List of tensor names or dictionary mapping labels to tensor
names. If not set, log _TENSORS_TO_LOG by default.
every_n_secs: `int`, the frequency for logging the metric. Default to every
10 mins.
**kwargs: a dictionary of arguments.
Returns:
Returns a LoggingMetricHook that saves tensor values in a JSON format.
"""
if tensors_to_log is None:
tensors_to_log = _TENSORS_TO_LOG
return metric_hook.LoggingMetricHook(
tensors=tensors_to_log,
metric_logger=logger.get_benchmark_logger(),
every_n_secs=every_n_secs)
def get_step_counter_hook(**kwargs):
"""Function to get StepCounterHook."""
del kwargs
return tf.estimator.StepCounterHook()
# A dictionary to map one hook name and its corresponding function
HOOKS = {
'loggingtensorhook': get_logging_tensor_hook,
'profilerhook': get_profiler_hook,
'examplespersecondhook': get_examples_per_second_hook,
'loggingmetrichook': get_logging_metric_hook,
'stepcounterhook': get_step_counter_hook
}
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for hooks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl import logging
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.r1.utils.logs import hooks
from official.r1.utils.logs import mock_lib
logging.set_verbosity(logging.DEBUG)
class ExamplesPerSecondHookTest(tf.test.TestCase):
"""Tests for the ExamplesPerSecondHook.
In the test, we explicitly run global_step tensor after train_op in order to
keep the global_step value and the train_op (which increase the glboal_step
by 1) consistent. This is to correct the discrepancies in reported global_step
value when running on GPUs.
"""
def setUp(self):
"""Mock out logging calls to verify if correct info is being monitored."""
self._logger = mock_lib.MockBenchmarkLogger()
self.graph = tf.Graph()
with self.graph.as_default():
tf.compat.v1.train.create_global_step()
self.train_op = tf.compat.v1.assign_add(
tf.compat.v1.train.get_global_step(), 1)
self.global_step = tf.compat.v1.train.get_global_step()
def test_raise_in_both_secs_and_steps(self):
with self.assertRaises(ValueError):
hooks.ExamplesPerSecondHook(
batch_size=256,
every_n_steps=10,
every_n_secs=20,
metric_logger=self._logger)
def test_raise_in_none_secs_and_steps(self):
with self.assertRaises(ValueError):
hooks.ExamplesPerSecondHook(
batch_size=256,
every_n_steps=None,
every_n_secs=None,
metric_logger=self._logger)
def _validate_log_every_n_steps(self, every_n_steps, warm_steps):
hook = hooks.ExamplesPerSecondHook(
batch_size=256,
every_n_steps=every_n_steps,
warm_steps=warm_steps,
metric_logger=self._logger)
with tf.compat.v1.train.MonitoredSession(
tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
for _ in range(every_n_steps):
# Explicitly run global_step after train_op to get the accurate
# global_step value
mon_sess.run(self.train_op)
mon_sess.run(self.global_step)
# Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
mon_sess.run(self.train_op)
global_step_val = mon_sess.run(self.global_step)
if global_step_val > warm_steps:
self._assert_metrics()
else:
# Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
# Add additional run to verify proper reset when called multiple times.
prev_log_len = len(self._logger.logged_metric)
mon_sess.run(self.train_op)
global_step_val = mon_sess.run(self.global_step)
if every_n_steps == 1 and global_step_val > warm_steps:
# Each time, we log two additional metrics. Did exactly 2 get added?
self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
else:
# No change in the size of the metric list.
self.assertEqual(len(self._logger.logged_metric), prev_log_len)
def test_examples_per_sec_every_1_steps(self):
with self.graph.as_default():
self._validate_log_every_n_steps(1, 0)
def test_examples_per_sec_every_5_steps(self):
with self.graph.as_default():
self._validate_log_every_n_steps(5, 0)
def test_examples_per_sec_every_1_steps_with_warm_steps(self):
with self.graph.as_default():
self._validate_log_every_n_steps(1, 10)
def test_examples_per_sec_every_5_steps_with_warm_steps(self):
with self.graph.as_default():
self._validate_log_every_n_steps(5, 10)
def _validate_log_every_n_secs(self, every_n_secs):
hook = hooks.ExamplesPerSecondHook(
batch_size=256,
every_n_steps=None,
every_n_secs=every_n_secs,
metric_logger=self._logger)
with tf.compat.v1.train.MonitoredSession(
tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
# Explicitly run global_step after train_op to get the accurate
# global_step value
mon_sess.run(self.train_op)
mon_sess.run(self.global_step)
# Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
time.sleep(every_n_secs)
mon_sess.run(self.train_op)
mon_sess.run(self.global_step)
self._assert_metrics()
def test_examples_per_sec_every_1_secs(self):
with self.graph.as_default():
self._validate_log_every_n_secs(1)
def test_examples_per_sec_every_5_secs(self):
with self.graph.as_default():
self._validate_log_every_n_secs(5)
def _assert_metrics(self):
metrics = self._logger.logged_metric
self.assertEqual(metrics[-2]["name"], "average_examples_per_sec")
self.assertEqual(metrics[-1]["name"], "current_examples_per_sec")
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logging utilities for benchmark.
For collecting local environment metrics like CPU and memory, certain python
packages need be installed. See README for details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import datetime
import json
import numbers
import os
import threading
import uuid
from absl import flags
from absl import logging
from six.moves import _thread as thread
import tensorflow as tf
from tensorflow.python.client import device_lib
from official.r1.utils.logs import cloud_lib
METRIC_LOG_FILE_NAME = "metric.log"
BENCHMARK_RUN_LOG_FILE_NAME = "benchmark_run.log"
_DATE_TIME_FORMAT_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ"
GCP_TEST_ENV = "GCP"
RUN_STATUS_SUCCESS = "success"
RUN_STATUS_FAILURE = "failure"
RUN_STATUS_RUNNING = "running"
FLAGS = flags.FLAGS
# Don't use it directly. Use get_benchmark_logger to access a logger.
_benchmark_logger = None
_logger_lock = threading.Lock()
def config_benchmark_logger(flag_obj=None):
"""Config the global benchmark logger."""
_logger_lock.acquire()
try:
global _benchmark_logger
if not flag_obj:
flag_obj = FLAGS
if (not hasattr(flag_obj, "benchmark_logger_type") or
flag_obj.benchmark_logger_type == "BaseBenchmarkLogger"):
_benchmark_logger = BaseBenchmarkLogger()
elif flag_obj.benchmark_logger_type == "BenchmarkFileLogger":
_benchmark_logger = BenchmarkFileLogger(flag_obj.benchmark_log_dir)
else:
raise ValueError("Unrecognized benchmark_logger_type: %s"
% flag_obj.benchmark_logger_type)
finally:
_logger_lock.release()
return _benchmark_logger
def get_benchmark_logger():
if not _benchmark_logger:
config_benchmark_logger()
return _benchmark_logger
@contextlib.contextmanager
def benchmark_context(flag_obj):
"""Context of benchmark, which will update status of the run accordingly."""
benchmark_logger = config_benchmark_logger(flag_obj)
try:
yield
benchmark_logger.on_finish(RUN_STATUS_SUCCESS)
except Exception: # pylint: disable=broad-except
# Catch all the exception, update the run status to be failure, and re-raise
benchmark_logger.on_finish(RUN_STATUS_FAILURE)
raise
class BaseBenchmarkLogger(object):
"""Class to log the benchmark information to STDOUT."""
def log_evaluation_result(self, eval_results):
"""Log the evaluation result.
The evaluate result is a dictionary that contains metrics defined in
model_fn. It also contains a entry for global_step which contains the value
of the global step when evaluation was performed.
Args:
eval_results: dict, the result of evaluate.
"""
if not isinstance(eval_results, dict):
logging.warning("eval_results should be dictionary for logging. Got %s",
type(eval_results))
return
global_step = eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP]
for key in sorted(eval_results):
if key != tf.compat.v1.GraphKeys.GLOBAL_STEP:
self.log_metric(key, eval_results[key], global_step=global_step)
def log_metric(self, name, value, unit=None, global_step=None, extras=None):
"""Log the benchmark metric information to local file.
Currently the logging is done in a synchronized way. This should be updated
to log asynchronously.
Args:
name: string, the name of the metric to log.
value: number, the value of the metric. The value will not be logged if it
is not a number type.
unit: string, the unit of the metric, E.g "image per second".
global_step: int, the global_step when the metric is logged.
extras: map of string:string, the extra information about the metric.
"""
metric = _process_metric_to_json(name, value, unit, global_step, extras)
if metric:
logging.info("Benchmark metric: %s", metric)
def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
logging.info(
"Benchmark run: %s",
_gather_run_info(model_name, dataset_name, run_params, test_id))
def on_finish(self, status):
pass
class BenchmarkFileLogger(BaseBenchmarkLogger):
"""Class to log the benchmark information to local disk."""
def __init__(self, logging_dir):
super(BenchmarkFileLogger, self).__init__()
self._logging_dir = logging_dir
if not tf.io.gfile.isdir(self._logging_dir):
tf.io.gfile.makedirs(self._logging_dir)
self._metric_file_handler = tf.io.gfile.GFile(
os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a")
def log_metric(self, name, value, unit=None, global_step=None, extras=None):
"""Log the benchmark metric information to local file.
Currently the logging is done in a synchronized way. This should be updated
to log asynchronously.
Args:
name: string, the name of the metric to log.
value: number, the value of the metric. The value will not be logged if it
is not a number type.
unit: string, the unit of the metric, E.g "image per second".
global_step: int, the global_step when the metric is logged.
extras: map of string:string, the extra information about the metric.
"""
metric = _process_metric_to_json(name, value, unit, global_step, extras)
if metric:
try:
json.dump(metric, self._metric_file_handler)
self._metric_file_handler.write("\n")
self._metric_file_handler.flush()
except (TypeError, ValueError) as e:
logging.warning(
"Failed to dump metric to log file: name %s, value %s, error %s",
name, value, e)
def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
"""Collect most of the TF runtime information for the local env.
The schema of the run info follows official/benchmark/datastore/schema.
Args:
model_name: string, the name of the model.
dataset_name: string, the name of dataset for training and evaluation.
run_params: dict, the dictionary of parameters for the run, it could
include hyperparameters or other params that are important for the run.
test_id: string, the unique name of the test run by the combination of key
parameters, eg batch size, num of GPU. It is hardware independent.
"""
run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)
with tf.io.gfile.GFile(os.path.join(
self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:
try:
json.dump(run_info, f)
f.write("\n")
except (TypeError, ValueError) as e:
logging.warning("Failed to dump benchmark run info to log file: %s", e)
def on_finish(self, status):
self._metric_file_handler.flush()
self._metric_file_handler.close()
def _gather_run_info(model_name, dataset_name, run_params, test_id):
"""Collect the benchmark run information for the local environment."""
run_info = {
"model_name": model_name,
"dataset": {"name": dataset_name},
"machine_config": {},
"test_id": test_id,
"run_date": datetime.datetime.utcnow().strftime(
_DATE_TIME_FORMAT_PATTERN)}
_collect_tensorflow_info(run_info)
_collect_tensorflow_environment_variables(run_info)
_collect_run_params(run_info, run_params)
_collect_memory_info(run_info)
_collect_test_environment(run_info)
return run_info
def _process_metric_to_json(
name, value, unit=None, global_step=None, extras=None):
"""Validate the metric data and generate JSON for insert."""
if not isinstance(value, numbers.Number):
logging.warning("Metric value to log should be a number. Got %s",
type(value))
return None
extras = _convert_to_json_dict(extras)
return {
"name": name,
"value": float(value),
"unit": unit,
"global_step": global_step,
"timestamp": datetime.datetime.utcnow().strftime(
_DATE_TIME_FORMAT_PATTERN),
"extras": extras}
def _collect_tensorflow_info(run_info):
run_info["tensorflow_version"] = {
"version": tf.version.VERSION, "git_hash": tf.version.GIT_VERSION}
def _collect_run_params(run_info, run_params):
"""Log the parameter information for the benchmark run."""
def process_param(name, value):
type_check = {
str: {"name": name, "string_value": value},
int: {"name": name, "long_value": value},
bool: {"name": name, "bool_value": str(value)},
float: {"name": name, "float_value": value},
}
return type_check.get(type(value),
{"name": name, "string_value": str(value)})
if run_params:
run_info["run_parameters"] = [
process_param(k, v) for k, v in sorted(run_params.items())]
def _collect_tensorflow_environment_variables(run_info):
run_info["tensorflow_environment_variables"] = [
{"name": k, "value": v}
for k, v in sorted(os.environ.items()) if k.startswith("TF_")]
def _collect_memory_info(run_info):
try:
# Note: psutil is not installed in the TensorFlow OSS tree.
# It is installable via pip.
import psutil # pylint: disable=g-import-not-at-top
vmem = psutil.virtual_memory()
run_info["machine_config"]["memory_total"] = vmem.total
run_info["machine_config"]["memory_available"] = vmem.available
except ImportError:
logging.warn("'psutil' not imported. Memory info will not be logged.")
def _collect_test_environment(run_info):
"""Detect the local environment, eg GCE, AWS or DGX, etc."""
if cloud_lib.on_gcp():
run_info["test_environment"] = GCP_TEST_ENV
# TODO(scottzhu): Add more testing env detection for other platform
def _parse_gpu_model(physical_device_desc):
# Assume all the GPU connected are same model
for kv in physical_device_desc.split(","):
k, _, v = kv.partition(":")
if k.strip() == "name":
return v.strip()
return None
def _convert_to_json_dict(input_dict):
if input_dict:
return [{"name": k, "value": v} for k, v in sorted(input_dict.items())]
else:
return []
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for benchmark logger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import tempfile
import time
import unittest
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
from official.r1.utils.logs import logger
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
class BenchmarkLoggerTest(tf.test.TestCase):
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BenchmarkLoggerTest, cls).setUpClass()
flags_core.define_benchmark()
def test_get_default_benchmark_logger(self):
with flagsaver.flagsaver(benchmark_logger_type="foo"):
self.assertIsInstance(logger.get_benchmark_logger(),
logger.BaseBenchmarkLogger)
def test_config_base_benchmark_logger(self):
with flagsaver.flagsaver(benchmark_logger_type="BaseBenchmarkLogger"):
logger.config_benchmark_logger()
self.assertIsInstance(logger.get_benchmark_logger(),
logger.BaseBenchmarkLogger)
def test_config_benchmark_file_logger(self):
# Set the benchmark_log_dir first since the benchmark_logger_type will need
# the value to be set when it does the validation.
with flagsaver.flagsaver(benchmark_log_dir="/tmp"):
with flagsaver.flagsaver(benchmark_logger_type="BenchmarkFileLogger"):
logger.config_benchmark_logger()
self.assertIsInstance(logger.get_benchmark_logger(),
logger.BenchmarkFileLogger)
class BaseBenchmarkLoggerTest(tf.test.TestCase):
def setUp(self):
super(BaseBenchmarkLoggerTest, self).setUp()
self._actual_log = logging.info
self.logged_message = None
def mock_log(*args, **kwargs):
self.logged_message = args
self._actual_log(*args, **kwargs)
logging.info = mock_log
def tearDown(self):
super(BaseBenchmarkLoggerTest, self).tearDown()
logging.info = self._actual_log
def test_log_metric(self):
log = logger.BaseBenchmarkLogger()
log.log_metric("accuracy", 0.999, global_step=1e4, extras={"name": "value"})
expected_log_prefix = "Benchmark metric:"
self.assertRegexpMatches(str(self.logged_message), expected_log_prefix)
class BenchmarkFileLoggerTest(tf.test.TestCase):
def setUp(self):
super(BenchmarkFileLoggerTest, self).setUp()
# Avoid pulling extra env vars from test environment which affects the test
# result, eg. Kokoro test has a TF_PKG env which affect the test case
# test_collect_tensorflow_environment_variables()
self.original_environ = dict(os.environ)
os.environ.clear()
def tearDown(self):
super(BenchmarkFileLoggerTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
os.environ.clear()
os.environ.update(self.original_environ)
def test_create_logging_dir(self):
non_exist_temp_dir = os.path.join(self.get_temp_dir(), "unknown_dir")
self.assertFalse(tf.io.gfile.isdir(non_exist_temp_dir))
logger.BenchmarkFileLogger(non_exist_temp_dir)
self.assertTrue(tf.io.gfile.isdir(non_exist_temp_dir))
def test_log_metric(self):
log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
log = logger.BenchmarkFileLogger(log_dir)
log.log_metric("accuracy", 0.999, global_step=1e4, extras={"name": "value"})
metric_log = os.path.join(log_dir, "metric.log")
self.assertTrue(tf.io.gfile.exists(metric_log))
with tf.io.gfile.GFile(metric_log) as f:
metric = json.loads(f.readline())
self.assertEqual(metric["name"], "accuracy")
self.assertEqual(metric["value"], 0.999)
self.assertEqual(metric["unit"], None)
self.assertEqual(metric["global_step"], 1e4)
self.assertEqual(metric["extras"], [{"name": "name", "value": "value"}])
def test_log_multiple_metrics(self):
log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
log = logger.BenchmarkFileLogger(log_dir)
log.log_metric("accuracy", 0.999, global_step=1e4, extras={"name": "value"})
log.log_metric("loss", 0.02, global_step=1e4)
metric_log = os.path.join(log_dir, "metric.log")
self.assertTrue(tf.io.gfile.exists(metric_log))
with tf.io.gfile.GFile(metric_log) as f:
accuracy = json.loads(f.readline())
self.assertEqual(accuracy["name"], "accuracy")
self.assertEqual(accuracy["value"], 0.999)
self.assertEqual(accuracy["unit"], None)
self.assertEqual(accuracy["global_step"], 1e4)
self.assertEqual(accuracy["extras"], [{"name": "name", "value": "value"}])
loss = json.loads(f.readline())
self.assertEqual(loss["name"], "loss")
self.assertEqual(loss["value"], 0.02)
self.assertEqual(loss["unit"], None)
self.assertEqual(loss["global_step"], 1e4)
self.assertEqual(loss["extras"], [])
def test_log_non_number_value(self):
log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
log = logger.BenchmarkFileLogger(log_dir)
const = tf.constant(1)
log.log_metric("accuracy", const)
metric_log = os.path.join(log_dir, "metric.log")
self.assertFalse(tf.io.gfile.exists(metric_log))
def test_log_evaluation_result(self):
eval_result = {"loss": 0.46237424,
"global_step": 207082,
"accuracy": 0.9285}
log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
log = logger.BenchmarkFileLogger(log_dir)
log.log_evaluation_result(eval_result)
metric_log = os.path.join(log_dir, "metric.log")
self.assertTrue(tf.io.gfile.exists(metric_log))
with tf.io.gfile.GFile(metric_log) as f:
accuracy = json.loads(f.readline())
self.assertEqual(accuracy["name"], "accuracy")
self.assertEqual(accuracy["value"], 0.9285)
self.assertEqual(accuracy["unit"], None)
self.assertEqual(accuracy["global_step"], 207082)
loss = json.loads(f.readline())
self.assertEqual(loss["name"], "loss")
self.assertEqual(loss["value"], 0.46237424)
self.assertEqual(loss["unit"], None)
self.assertEqual(loss["global_step"], 207082)
def test_log_evaluation_result_with_invalid_type(self):
eval_result = "{'loss': 0.46237424, 'global_step': 207082}"
log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
log = logger.BenchmarkFileLogger(log_dir)
log.log_evaluation_result(eval_result)
metric_log = os.path.join(log_dir, "metric.log")
self.assertFalse(tf.io.gfile.exists(metric_log))
def test_collect_tensorflow_info(self):
run_info = {}
logger._collect_tensorflow_info(run_info)
self.assertNotEqual(run_info["tensorflow_version"], {})
self.assertEqual(run_info["tensorflow_version"]["version"],
tf.version.VERSION)
self.assertEqual(run_info["tensorflow_version"]["git_hash"],
tf.version.GIT_VERSION)
def test_collect_run_params(self):
run_info = {}
run_parameters = {
"batch_size": 32,
"synthetic_data": True,
"train_epochs": 100.00,
"dtype": "fp16",
"resnet_size": 50,
"random_tensor": tf.constant(2.0)
}
logger._collect_run_params(run_info, run_parameters)
self.assertEqual(len(run_info["run_parameters"]), 6)
self.assertEqual(run_info["run_parameters"][0],
{"name": "batch_size", "long_value": 32})
self.assertEqual(run_info["run_parameters"][1],
{"name": "dtype", "string_value": "fp16"})
v1_tensor = {"name": "random_tensor", "string_value":
"Tensor(\"Const:0\", shape=(), dtype=float32)"}
v2_tensor = {"name": "random_tensor", "string_value":
"tf.Tensor(2.0, shape=(), dtype=float32)"}
self.assertIn(run_info["run_parameters"][2], [v1_tensor, v2_tensor])
self.assertEqual(run_info["run_parameters"][3],
{"name": "resnet_size", "long_value": 50})
self.assertEqual(run_info["run_parameters"][4],
{"name": "synthetic_data", "bool_value": "True"})
self.assertEqual(run_info["run_parameters"][5],
{"name": "train_epochs", "float_value": 100.00})
def test_collect_tensorflow_environment_variables(self):
os.environ["TF_ENABLE_WINOGRAD_NONFUSED"] = "1"
os.environ["TF_OTHER"] = "2"
os.environ["OTHER"] = "3"
run_info = {}
logger._collect_tensorflow_environment_variables(run_info)
self.assertIsNotNone(run_info["tensorflow_environment_variables"])
expected_tf_envs = [
{"name": "TF_ENABLE_WINOGRAD_NONFUSED", "value": "1"},
{"name": "TF_OTHER", "value": "2"},
]
self.assertEqual(run_info["tensorflow_environment_variables"],
expected_tf_envs)
def test_collect_memory_info(self):
run_info = {"machine_config": {}}
logger._collect_memory_info(run_info)
self.assertIsNotNone(run_info["machine_config"]["memory_total"])
self.assertIsNotNone(run_info["machine_config"]["memory_available"])
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Session hook for logging benchmark metric."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
class LoggingMetricHook(tf.estimator.LoggingTensorHook):
"""Hook to log benchmark metric information.
This hook is very similar as tf.train.LoggingTensorHook, which logs given
tensors every N local steps, every N seconds, or at the end. The metric
information will be logged to given log_dir or via metric_logger in JSON
format, which can be consumed by data analysis pipeline later.
Note that if `at_end` is True, `tensors` should not include any tensor
whose evaluation produces a side effect such as consuming additional inputs.
"""
def __init__(self, tensors, metric_logger=None,
every_n_iter=None, every_n_secs=None, at_end=False):
"""Initializer for LoggingMetricHook.
Args:
tensors: `dict` that maps string-valued tags to tensors/tensor names,
or `iterable` of tensors/tensor names.
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
hook should use to write the log.
every_n_iter: `int`, print the values of `tensors` once every N local
steps taken on the current worker.
every_n_secs: `int` or `float`, print the values of `tensors` once every N
seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
provided.
at_end: `bool` specifying whether to print the values of `tensors` at the
end of the run.
Raises:
ValueError:
1. `every_n_iter` is non-positive, or
2. Exactly one of every_n_iter and every_n_secs should be provided.
3. Exactly one of log_dir and metric_logger should be provided.
"""
super(LoggingMetricHook, self).__init__(
tensors=tensors,
every_n_iter=every_n_iter,
every_n_secs=every_n_secs,
at_end=at_end)
if metric_logger is None:
raise ValueError("metric_logger should be provided.")
self._logger = metric_logger
def begin(self):
super(LoggingMetricHook, self).begin()
self._global_step_tensor = tf.compat.v1.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use LoggingMetricHook.")
if self._global_step_tensor.name not in self._current_tensors:
self._current_tensors[self._global_step_tensor.name] = (
self._global_step_tensor)
def after_run(self, unused_run_context, run_values):
# should_trigger is a internal state that populated at before_run, and it is
# using self_timer to determine whether it should trigger.
if self._should_trigger:
self._log_metric(run_values.results)
self._iter_count += 1
def end(self, session):
if self._log_at_end:
values = session.run(self._current_tensors)
self._log_metric(values)
def _log_metric(self, tensor_values):
self._timer.update_last_triggered_step(self._iter_count)
global_step = tensor_values[self._global_step_tensor.name]
# self._tag_order is populated during the init of LoggingTensorHook
for tag in self._tag_order:
self._logger.log_metric(tag, tensor_values[tag], global_step=global_step)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metric_hook."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
import time
import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.training import monitored_session # pylint: disable=g-bad-import-order
from official.r1.utils.logs import metric_hook
from official.r1.utils.logs import mock_lib
class LoggingMetricHookTest(tf.test.TestCase):
"""Tests for LoggingMetricHook."""
def setUp(self):
super(LoggingMetricHookTest, self).setUp()
self._log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
self._logger = mock_lib.MockBenchmarkLogger()
def tearDown(self):
super(LoggingMetricHookTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
def test_illegal_args(self):
with self.assertRaisesRegexp(ValueError, "nvalid every_n_iter"):
metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=0)
with self.assertRaisesRegexp(ValueError, "nvalid every_n_iter"):
metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=-10)
with self.assertRaisesRegexp(ValueError, "xactly one of"):
metric_hook.LoggingMetricHook(
tensors=["t"], every_n_iter=5, every_n_secs=5)
with self.assertRaisesRegexp(ValueError, "xactly one of"):
metric_hook.LoggingMetricHook(tensors=["t"])
with self.assertRaisesRegexp(ValueError, "metric_logger"):
metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5)
def test_print_at_end_only(self):
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
t = tf.constant(42.0, name="foo")
train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook(
tensors=[t.name], at_end=True, metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.compat.v1.global_variables_initializer())
for _ in range(3):
mon_sess.run(train_op)
self.assertEqual(self._logger.logged_metric, [])
hook.end(sess)
self.assertEqual(len(self._logger.logged_metric), 1)
metric = self._logger.logged_metric[0]
self.assertRegexpMatches(metric["name"], "foo")
self.assertEqual(metric["value"], 42.0)
self.assertEqual(metric["unit"], None)
self.assertEqual(metric["global_step"], 0)
def test_global_step_not_found(self):
with tf.Graph().as_default():
t = tf.constant(42.0, name="foo")
hook = metric_hook.LoggingMetricHook(
tensors=[t.name], at_end=True, metric_logger=self._logger)
with self.assertRaisesRegexp(
RuntimeError, "should be created to use LoggingMetricHook."):
hook.begin()
def test_log_tensors(self):
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
t1 = tf.constant(42.0, name="foo")
t2 = tf.constant(43.0, name="bar")
train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook(
tensors=[t1, t2], at_end=True, metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.compat.v1.global_variables_initializer())
for _ in range(3):
mon_sess.run(train_op)
self.assertEqual(self._logger.logged_metric, [])
hook.end(sess)
self.assertEqual(len(self._logger.logged_metric), 2)
metric1 = self._logger.logged_metric[0]
self.assertRegexpMatches(str(metric1["name"]), "foo")
self.assertEqual(metric1["value"], 42.0)
self.assertEqual(metric1["unit"], None)
self.assertEqual(metric1["global_step"], 0)
metric2 = self._logger.logged_metric[1]
self.assertRegexpMatches(str(metric2["name"]), "bar")
self.assertEqual(metric2["value"], 43.0)
self.assertEqual(metric2["unit"], None)
self.assertEqual(metric2["global_step"], 0)
def _validate_print_every_n_steps(self, sess, at_end):
t = tf.constant(42.0, name="foo")
train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook(
tensors=[t.name], every_n_iter=10, at_end=at_end,
metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.compat.v1.global_variables_initializer())
mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
for _ in range(3):
self._logger.logged_metric = []
for _ in range(9):
mon_sess.run(train_op)
# assertNotRegexpMatches is not supported by python 3.1 and later
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
# Add additional run to verify proper reset when called multiple times.
self._logger.logged_metric = []
mon_sess.run(train_op)
# assertNotRegexpMatches is not supported by python 3.1 and later
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
self._logger.logged_metric = []
hook.end(sess)
if at_end:
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
else:
# assertNotRegexpMatches is not supported by python 3.1 and later
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
def test_print_every_n_steps(self):
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_steps(sess, at_end=False)
# Verify proper reset.
self._validate_print_every_n_steps(sess, at_end=False)
def test_print_every_n_steps_and_end(self):
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_steps(sess, at_end=True)
# Verify proper reset.
self._validate_print_every_n_steps(sess, at_end=True)
def _validate_print_every_n_secs(self, sess, at_end):
t = tf.constant(42.0, name="foo")
train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook(
tensors=[t.name], every_n_secs=1.0, at_end=at_end,
metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.compat.v1.global_variables_initializer())
mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
# assertNotRegexpMatches is not supported by python 3.1 and later
self._logger.logged_metric = []
mon_sess.run(train_op)
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
time.sleep(1.0)
self._logger.logged_metric = []
mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
self._logger.logged_metric = []
hook.end(sess)
if at_end:
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
else:
# assertNotRegexpMatches is not supported by python 3.1 and later
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
def test_print_every_n_secs(self):
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_secs(sess, at_end=False)
# Verify proper reset.
self._validate_print_every_n_secs(sess, at_end=False)
def test_print_every_n_secs_and_end(self):
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
tf.compat.v1.train.get_or_create_global_step()
self._validate_print_every_n_secs(sess, at_end=True)
# Verify proper reset.
self._validate_print_every_n_secs(sess, at_end=True)
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for the mlperf logging utils.
MLPerf compliance logging is only desired under a limited set of circumstances.
This module is intended to keep users from needing to consider logging (or
install the module) unless they are performing mlperf runs.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import json
import os
import re
import subprocess
import sys
from absl import logging
import typing
# pylint:disable=logging-format-interpolation
_MIN_VERSION = (0, 0, 10)
_STACK_OFFSET = 2
SUDO = "sudo" if os.geteuid() else ""
# This indirection is used in docker.
DROP_CACHE_LOC = os.getenv("DROP_CACHE_LOC", "/proc/sys/vm/drop_caches")
_NCF_PREFIX = "NCF_RAW_"
# TODO(robieta): move line parsing to mlperf util
_PREFIX = r"(?:{})?:::MLPv([0-9]+).([0-9]+).([0-9]+)".format(_NCF_PREFIX)
_BENCHMARK = r"([a-zA-Z0-9_]+)"
_TIMESTAMP = r"([0-9]+\.[0-9]+)"
_CALLSITE = r"\((.+):([0-9]+)\)"
_TAG = r"([a-zA-Z0-9_]+)"
_VALUE = r"(.*)"
ParsedLine = namedtuple("ParsedLine", ["version", "benchmark", "timestamp",
"callsite", "tag", "value"])
LINE_PATTERN = re.compile(
"^{prefix} {benchmark} {timestamp} {callsite} {tag}(: |$){value}?$".format(
prefix=_PREFIX, benchmark=_BENCHMARK, timestamp=_TIMESTAMP,
callsite=_CALLSITE, tag=_TAG, value=_VALUE))
def parse_line(line): # type: (str) -> typing.Optional[ParsedLine]
match = LINE_PATTERN.match(line.strip())
if not match:
return
major, minor, micro, benchmark, timestamp = match.groups()[:5]
call_file, call_line, tag, _, value = match.groups()[5:]
return ParsedLine(version=(int(major), int(minor), int(micro)),
benchmark=benchmark, timestamp=timestamp,
callsite=(call_file, call_line), tag=tag, value=value)
def unparse_line(parsed_line): # type: (ParsedLine) -> str
version_str = "{}.{}.{}".format(*parsed_line.version)
callsite_str = "({}:{})".format(*parsed_line.callsite)
value_str = ": {}".format(parsed_line.value) if parsed_line.value else ""
return ":::MLPv{} {} {} {} {} {}".format(
version_str, parsed_line.benchmark, parsed_line.timestamp, callsite_str,
parsed_line.tag, value_str)
def get_mlperf_log():
"""Shielded import of mlperf_log module."""
try:
import mlperf_compliance
def test_mlperf_log_pip_version():
"""Check that mlperf_compliance is up to date."""
import pkg_resources
version = pkg_resources.get_distribution("mlperf_compliance")
version = tuple(int(i) for i in version.version.split("."))
if version < _MIN_VERSION:
logging.warning("mlperf_compliance is version {}, must be >= {}".format(
".".join([str(i) for i in version]),
".".join([str(i) for i in _MIN_VERSION])))
raise ImportError
return mlperf_compliance.mlperf_log
mlperf_log = test_mlperf_log_pip_version()
except ImportError:
mlperf_log = None
return mlperf_log
class Logger(object):
"""MLPerf logger indirection class.
This logger only logs for MLPerf runs, and prevents various errors associated
with not having the mlperf_compliance package installed.
"""
class Tags(object):
def __init__(self, mlperf_log):
self._enabled = False
self._mlperf_log = mlperf_log
def __getattr__(self, item):
if self._mlperf_log is None or not self._enabled:
return
return getattr(self._mlperf_log, item)
def __init__(self):
self._enabled = False
self._mlperf_log = get_mlperf_log()
self.tags = self.Tags(self._mlperf_log)
def __call__(self, enable=False):
if enable and self._mlperf_log is None:
raise ImportError("MLPerf logging was requested, but mlperf_compliance "
"module could not be loaded.")
self._enabled = enable
self.tags._enabled = enable
return self
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
self._enabled = False
self.tags._enabled = False
@property
def log_file(self):
if self._mlperf_log is None:
return
return self._mlperf_log.LOG_FILE
@property
def enabled(self):
return self._enabled
def ncf_print(self, key, value=None, stack_offset=_STACK_OFFSET,
deferred=False, extra_print=False, prefix=_NCF_PREFIX):
if self._mlperf_log is None or not self.enabled:
return
self._mlperf_log.ncf_print(key=key, value=value, stack_offset=stack_offset,
deferred=deferred, extra_print=extra_print,
prefix=prefix)
def set_ncf_root(self, path):
if self._mlperf_log is None:
return
self._mlperf_log.ROOT_DIR_NCF = path
LOGGER = Logger()
ncf_print, set_ncf_root = LOGGER.ncf_print, LOGGER.set_ncf_root
TAGS = LOGGER.tags
def clear_system_caches():
if not LOGGER.enabled:
return
ret_code = subprocess.call(
["sync && echo 3 | {} tee {}".format(SUDO, DROP_CACHE_LOC)],
shell=True)
if ret_code:
raise ValueError("Failed to clear caches")
if __name__ == "__main__":
logging.set_verbosity(logging.INFO)
with LOGGER(True):
ncf_print(key=TAGS.RUN_START)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mock objects and related functions for testing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class MockBenchmarkLogger(object):
"""This is a mock logger that can be used in dependent tests."""
def __init__(self):
self.logged_metric = []
def log_metric(self, name, value, unit=None, global_step=None,
extras=None):
self.logged_metric.append({
"name": name,
"value": float(value),
"unit": unit,
"global_step": global_step,
"extras": extras})
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions specific to running TensorFlow on TPUs."""
import tensorflow as tf
# "local" is a magic word in the TPU cluster resolver; it informs the resolver
# to use the local CPU as the compute device. This is useful for testing and
# debugging; the code flow is ostensibly identical, but without the need to
# actually have a TPU on the other end.
LOCAL = "local"
def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
"""Construct a host call to log scalars when training on TPU.
Args:
metric_dict: A dict of the tensors to be logged.
model_dir: The location to write the summary.
prefix: The prefix (if any) to prepend to the metric names.
Returns:
A tuple of (function, args_to_be_passed_to_said_function)
"""
# type: (dict, str) -> (function, list)
metric_names = list(metric_dict.keys())
def host_call_fn(global_step, *args):
"""Training host call. Creates scalar summaries for training metrics.
This function is executed on the CPU and should not directly reference
any Tensors in the rest of the `model_fn`. To pass Tensors from the
model to the `metric_fn`, provide as part of the `host_call`. See
https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `host_call`.
Args:
global_step: `Tensor with shape `[batch]` for the global_step
*args: Remaining tensors to log.
Returns:
List of summary ops to run on the CPU host.
"""
step = global_step[0]
with tf.compat.v1.summary.create_file_writer(
logdir=model_dir, filename_suffix=".host_call").as_default():
with tf.compat.v1.summary.always_record_summaries():
for i, name in enumerate(metric_names):
tf.compat.v1.summary.scalar(prefix + name, args[i][0], step=step)
return tf.compat.v1.summary.all_summary_ops()
# To log the current learning rate, and gradient norm for Tensorboard, the
# summary op needs to be run on the host CPU via host_call. host_call
# expects [batch_size, ...] Tensors, thus reshape to introduce a batch
# dimension. These Tensors are implicitly concatenated to
# [params['batch_size']].
global_step_tensor = tf.reshape(
tf.compat.v1.train.get_or_create_global_step(), [1])
other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]
return host_call_fn, [global_step_tensor] + other_tensors
def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"):
"""Performs embedding lookup via a matmul.
The matrix to be multiplied by the embedding table Tensor is constructed
via an implementation of scatter based on broadcasting embedding indices
and performing an equality comparison against a broadcasted
range(num_embedding_table_rows). All masked positions will produce an
embedding vector of zeros.
Args:
embedding_table: Tensor of embedding table.
Rank 2 (table_size x embedding dim)
values: Tensor of embedding indices. Rank 2 (batch x n_indices)
mask: Tensor of mask / weights. Rank 2 (batch x n_indices)
name: Optional name scope for created ops
Returns:
Rank 3 tensor of embedding vectors.
"""
with tf.name_scope(name):
n_embeddings = embedding_table.get_shape().as_list()[0]
batch_size, padded_size = values.shape.as_list()
emb_idcs = tf.tile(
tf.reshape(values, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
emb_weights = tf.tile(
tf.reshape(mask, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
col_idcs = tf.tile(
tf.reshape(tf.range(n_embeddings), (1, 1, n_embeddings)),
(batch_size, padded_size, 1))
one_hot = tf.where(
tf.equal(emb_idcs, col_idcs), emb_weights,
tf.zeros((batch_size, padded_size, n_embeddings)))
return tf.tensordot(one_hot, embedding_table, 1)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test TPU optimized matmul embedding."""
import numpy as np
import tensorflow as tf
from official.r1.utils import tpu as tpu_utils
TEST_CASES = [
dict(embedding_dim=256, vocab_size=1000, sequence_length=64,
batch_size=32, seed=54131),
dict(embedding_dim=8, vocab_size=15, sequence_length=12,
batch_size=256, seed=536413),
dict(embedding_dim=2048, vocab_size=512, sequence_length=50,
batch_size=8, seed=35124)
]
class TPUBaseTester(tf.test.TestCase):
def construct_embedding_and_values(self, embedding_dim, vocab_size,
sequence_length, batch_size, seed):
np.random.seed(seed)
embeddings = np.random.random(size=(vocab_size, embedding_dim))
embedding_table = tf.convert_to_tensor(value=embeddings, dtype=tf.float32)
tokens = np.random.randint(low=1, high=vocab_size-1,
size=(batch_size, sequence_length))
for i in range(batch_size):
tokens[i, np.random.randint(low=0, high=sequence_length-1):] = 0
values = tf.convert_to_tensor(value=tokens, dtype=tf.int32)
mask = tf.cast(tf.not_equal(values, 0), dtype=tf.float32)
return embedding_table, values, mask
def _test_embedding(self, embedding_dim, vocab_size,
sequence_length, batch_size, seed):
"""Test that matmul embedding matches embedding lookup (gather)."""
with self.test_session():
embedding_table, values, mask = self.construct_embedding_and_values(
embedding_dim=embedding_dim,
vocab_size=vocab_size,
sequence_length=sequence_length,
batch_size=batch_size,
seed=seed
)
embedding = (tf.nn.embedding_lookup(params=embedding_table, ids=values) *
tf.expand_dims(mask, -1))
matmul_embedding = tpu_utils.embedding_matmul(
embedding_table=embedding_table, values=values, mask=mask)
self.assertAllClose(embedding, matmul_embedding)
def _test_masking(self, embedding_dim, vocab_size,
sequence_length, batch_size, seed):
"""Test that matmul embedding properly zeros masked positions."""
with self.test_session():
embedding_table, values, mask = self.construct_embedding_and_values(
embedding_dim=embedding_dim,
vocab_size=vocab_size,
sequence_length=sequence_length,
batch_size=batch_size,
seed=seed
)
matmul_embedding = tpu_utils.embedding_matmul(
embedding_table=embedding_table, values=values, mask=mask)
self.assertAllClose(matmul_embedding,
matmul_embedding * tf.expand_dims(mask, -1))
def test_embedding_0(self):
self._test_embedding(**TEST_CASES[0])
def test_embedding_1(self):
self._test_embedding(**TEST_CASES[1])
def test_embedding_2(self):
self._test_embedding(**TEST_CASES[2])
def test_masking_0(self):
self._test_masking(**TEST_CASES[0])
def test_masking_1(self):
self._test_masking(**TEST_CASES[1])
def test_masking_2(self):
self._test_masking(**TEST_CASES[2])
if __name__ == "__main__":
tf.test.main()
![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
# Predicting Income with the Census Income Dataset
The implementation is based on TensorFlow 1.x.
## Overview
The [Census Income Data Set](https://archive.ics.uci.edu/ml/datasets/Census+Income) contains over 48,000 samples with attributes including age, occupation, education, and income (a binary label, either `>50K` or `<=50K`). The dataset is split into roughly 32,000 training and 16,000 testing samples.
Here, we use the [wide and deep model](https://research.googleblog.com/2016/06/wide-deep-learning-better-together-with.html) to predict the income labels. The **wide model** is able to memorize interactions with data with a large number of features but not able to generalize these learned interactions on new data. The **deep model** generalizes well but is unable to learn exceptions within the data. The **wide and deep model** combines the two models and is able to generalize while learning exceptions.
For the purposes of this example code, the Census Income Data Set was chosen to allow the model to train in a reasonable amount of time. You'll notice that the deep model performs almost as well as the wide and deep model on this dataset. The wide and deep model truly shines on larger data sets with high-cardinality features, where each feature has millions/billions of unique possible values (which is the specialty of the wide model).
Finally, a key point. As a modeler and developer, think about how this dataset is used and the potential benefits and harm a model's predictions can cause. A model like this could reinforce societal biases and disparities. Is a feature relevant to the problem you want to solve, or will it introduce bias? For more information, read about [ML fairness](https://developers.google.com/machine-learning/fairness-overview/).
---
The code sample in this directory uses the high level `tf.estimator.Estimator` API. This API is great for fast iteration and quickly adapting models to your own datasets without major code overhauls. It allows you to move from single-worker training to distributed training, and it makes it easy to export model binaries for prediction.
The input function for the `Estimator` uses `tf.contrib.data.TextLineDataset`, which creates a `Dataset` object. The `Dataset` API makes it easy to apply transformations (map, batch, shuffle, etc.) to the data. [Read more here](https://www.tensorflow.org/guide/datasets).
The `Estimator` and `Dataset` APIs are both highly encouraged for fast development and efficient training.
## Running the code
First make sure you've [added the models folder to your Python path](/official/#running-the-models); otherwise you may encounter an error like `ImportError: No module named official.wide_deep`.
### Setup
The [Census Income Data Set](https://archive.ics.uci.edu/ml/datasets/Census+Income) that this sample uses for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). We have provided a script that downloads and cleans the necessary files.
```
python census_dataset.py
```
This will download the files to `/tmp/census_data`. To change the directory, set the `--data_dir` flag.
### Training
You can run the code locally as follows:
```
python census_main.py
```
The model is saved to `/tmp/census_model` by default, which can be changed using the `--model_dir` flag.
To run the *wide* or *deep*-only models, set the `--model_type` flag to `wide` or `deep`. Other flags are configurable as well; see `census_main.py` for details.
The final accuracy should be over 83% with any of the three model types.
You can also experiment with `-inter` and `-intra` flag to explore inter/intra op parallelism for potential better performance as follows:
```
python census_main.py --inter=<int> --intra=<int>
```
Please note the above optional inter/intra op does not affect model accuracy. These are TensorFlow framework configurations that only affect execution time.
For more details regarding the above inter/intra flags, please refer to [Optimizing_for_CPU](https://www.tensorflow.org/performance/performance_guide#optimizing_for_cpu) or [TensorFlow config.proto source code](https://github.com/tensorflow/tensorflow/blob/26b4dfa65d360f2793ad75083c797d57f8661b93/tensorflow/core/protobuf/config.proto#L165).
### TensorBoard
Run TensorBoard to inspect the details about the graph and training progression.
```
tensorboard --logdir=/tmp/census_model
```
## Inference with SavedModel
You can export the model into Tensorflow [SavedModel](https://www.tensorflow.org/guide/saved_model) format by using the argument `--export_dir`:
```
python census_main.py --export_dir /tmp/wide_deep_saved_model
```
After the model finishes training, use [`saved_model_cli`](https://www.tensorflow.org/guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
Try the following commands to inspect the SavedModel:
**Replace `${TIMESTAMP}` with the folder produced (e.g. 1524249124)**
```
# List possible tag_sets. Only one metagraph is saved, so there will be one option.
saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/
# Show SignatureDefs for tag_set=serve. SignatureDefs define the outputs to show.
saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
--tag_set serve --all
```
### Inference
Let's use the model to predict the income group of two examples:
```
saved_model_cli run --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
--tag_set serve --signature_def="predict" \
--input_examples='examples=[{"age":[46.], "education_num":[10.], "capital_gain":[7688.], "capital_loss":[0.], "hours_per_week":[38.]}, {"age":[24.], "education_num":[13.], "capital_gain":[0.], "capital_loss":[0.], "hours_per_week":[50.]}]'
```
This will print out the predicted classes and class probabilities. Class 0 is the <=50k group and 1 is the >50k group.
## Additional Links
If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).
You can also [run this model on Cloud ML Engine](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction), which provides [hyperparameter tuning](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#hyperparameter_tuning) to maximize your model's results and enables [deploying your model for prediction](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#deploy_a_model_to_support_prediction).
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and clean the Census Income Dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
# pylint: disable=wrong-import-order
from absl import app as absl_app
from absl import flags
from six.moves import urllib
from six.moves import zip
import tensorflow.compat.v1 as tf
# pylint: enable=wrong-import-order
from official.utils.flags import core as flags_core
DATA_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult'
TRAINING_FILE = 'adult.data'
TRAINING_URL = '%s/%s' % (DATA_URL, TRAINING_FILE)
EVAL_FILE = 'adult.test'
EVAL_URL = '%s/%s' % (DATA_URL, EVAL_FILE)
_CSV_COLUMNS = [
'age', 'workclass', 'fnlwgt', 'education', 'education_num',
'marital_status', 'occupation', 'relationship', 'race', 'gender',
'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
'income_bracket'
]
_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
[0], [0], [0], [''], ['']]
_HASH_BUCKET_SIZE = 1000
_NUM_EXAMPLES = {
'train': 32561,
'validation': 16281,
}
def _download_and_clean_file(filename, url):
"""Downloads data from url, and makes changes to match the CSV format."""
temp_file, _ = urllib.request.urlretrieve(url)
with tf.gfile.Open(temp_file, 'r') as temp_eval_file:
with tf.gfile.Open(filename, 'w') as eval_file:
for line in temp_eval_file:
line = line.strip()
line = line.replace(', ', ',')
if not line or ',' not in line:
continue
if line[-1] == '.':
line = line[:-1]
line += '\n'
eval_file.write(line)
tf.gfile.Remove(temp_file)
def download(data_dir):
"""Download census data if it is not already present."""
tf.gfile.MakeDirs(data_dir)
training_file_path = os.path.join(data_dir, TRAINING_FILE)
if not tf.gfile.Exists(training_file_path):
_download_and_clean_file(training_file_path, TRAINING_URL)
eval_file_path = os.path.join(data_dir, EVAL_FILE)
if not tf.gfile.Exists(eval_file_path):
_download_and_clean_file(eval_file_path, EVAL_URL)
def build_model_columns():
"""Builds a set of wide and deep feature columns."""
# Continuous variable columns
age = tf.feature_column.numeric_column('age')
education_num = tf.feature_column.numeric_column('education_num')
capital_gain = tf.feature_column.numeric_column('capital_gain')
capital_loss = tf.feature_column.numeric_column('capital_loss')
hours_per_week = tf.feature_column.numeric_column('hours_per_week')
education = tf.feature_column.categorical_column_with_vocabulary_list(
'education', [
'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
'5th-6th', '10th', '1st-4th', 'Preschool', '12th'])
marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
'marital_status', [
'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])
relationship = tf.feature_column.categorical_column_with_vocabulary_list(
'relationship', [
'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
'Other-relative'])
workclass = tf.feature_column.categorical_column_with_vocabulary_list(
'workclass', [
'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])
# To show an example of hashing:
occupation = tf.feature_column.categorical_column_with_hash_bucket(
'occupation', hash_bucket_size=_HASH_BUCKET_SIZE)
# Transformations.
age_buckets = tf.feature_column.bucketized_column(
age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
# Wide columns and deep columns.
base_columns = [
education, marital_status, relationship, workclass, occupation,
age_buckets,
]
crossed_columns = [
tf.feature_column.crossed_column(
['education', 'occupation'], hash_bucket_size=_HASH_BUCKET_SIZE),
tf.feature_column.crossed_column(
[age_buckets, 'education', 'occupation'],
hash_bucket_size=_HASH_BUCKET_SIZE),
]
wide_columns = base_columns + crossed_columns
deep_columns = [
age,
education_num,
capital_gain,
capital_loss,
hours_per_week,
tf.feature_column.indicator_column(workclass),
tf.feature_column.indicator_column(education),
tf.feature_column.indicator_column(marital_status),
tf.feature_column.indicator_column(relationship),
# To show an example of embedding
tf.feature_column.embedding_column(occupation, dimension=8),
]
return wide_columns, deep_columns
def input_fn(data_file, num_epochs, shuffle, batch_size):
"""Generate an input function for the Estimator."""
assert tf.gfile.Exists(data_file), (
'%s not found. Please make sure you have run census_dataset.py and '
'set the --data_dir argument to the correct path.' % data_file)
def parse_csv(value):
tf.logging.info('Parsing {}'.format(data_file))
columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
features = dict(list(zip(_CSV_COLUMNS, columns)))
labels = features.pop('income_bracket')
classes = tf.equal(labels, '>50K') # binary classification
return features, classes
# Extract lines from input files using the Dataset API.
dataset = tf.data.TextLineDataset(data_file)
if shuffle:
dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])
dataset = dataset.map(parse_csv, num_parallel_calls=5)
# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
return dataset
def define_data_download_flags():
"""Add flags specifying data download arguments."""
flags.DEFINE_string(
name="data_dir", default="/tmp/census_data/",
help=flags_core.help_wrap(
"Directory to download and extract data."))
def main(_):
download(flags.FLAGS.data_dir)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
define_data_download_flags()
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train DNN on census income dataset."""
import os
from absl import app as absl_app
from absl import flags
import tensorflow.compat.v1 as tf
from official.r1.utils.logs import logger
from official.r1.wide_deep import census_dataset
from official.r1.wide_deep import wide_deep_run_loop
from official.utils.flags import core as flags_core
def define_census_flags():
wide_deep_run_loop.define_wide_deep_flags()
flags.adopt_module_key_flags(wide_deep_run_loop)
flags_core.set_defaults(data_dir='/tmp/census_data',
model_dir='/tmp/census_model',
train_epochs=40,
epochs_between_evals=2,
inter_op_parallelism_threads=0,
intra_op_parallelism_threads=0,
batch_size=40)
def build_estimator(model_dir, model_type, model_column_fn, inter_op, intra_op):
"""Build an estimator appropriate for the given model type."""
wide_columns, deep_columns = model_column_fn()
hidden_units = [100, 75, 50, 25]
# Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
# trains faster than GPU for this model.
run_config = tf.estimator.RunConfig().replace(
session_config=tf.ConfigProto(device_count={'GPU': 0},
inter_op_parallelism_threads=inter_op,
intra_op_parallelism_threads=intra_op))
if model_type == 'wide':
return tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=wide_columns,
config=run_config)
elif model_type == 'deep':
return tf.estimator.DNNClassifier(
model_dir=model_dir,
feature_columns=deep_columns,
hidden_units=hidden_units,
config=run_config)
else:
return tf.estimator.DNNLinearCombinedClassifier(
model_dir=model_dir,
linear_feature_columns=wide_columns,
dnn_feature_columns=deep_columns,
dnn_hidden_units=hidden_units,
config=run_config)
def run_census(flags_obj):
"""Construct all necessary functions and call run_loop.
Args:
flags_obj: Object containing user specified flags.
"""
if flags_obj.download_if_missing:
census_dataset.download(flags_obj.data_dir)
train_file = os.path.join(flags_obj.data_dir, census_dataset.TRAINING_FILE)
test_file = os.path.join(flags_obj.data_dir, census_dataset.EVAL_FILE)
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
def train_input_fn():
return census_dataset.input_fn(
train_file, flags_obj.epochs_between_evals, True, flags_obj.batch_size)
def eval_input_fn():
return census_dataset.input_fn(test_file, 1, False, flags_obj.batch_size)
tensors_to_log = {
'average_loss': '{loss_prefix}head/truediv',
'loss': '{loss_prefix}head/weighted_loss/Sum'
}
wide_deep_run_loop.run_loop(
name="Census Income", train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
model_column_fn=census_dataset.build_model_columns,
build_estimator_fn=build_estimator,
flags_obj=flags_obj,
tensors_to_log=tensors_to_log,
early_stop=True)
def main(_):
with logger.benchmark_context(flags.FLAGS):
run_census(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
define_census_flags()
absl_app.run(main)
39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,,,2174,0,40,,<=50K
50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,,,0,0,13,,<=50K
38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,,,0,0,40,,<=50K
53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,,,0,0,40,,<=50K
28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,,,0,0,40,,<=50K
37,Private,284582,Masters,14,Married-civ-spouse,Exec-managerial,Wife,,,0,0,40,,<=50K
49,Private,160187,9th,5,Married-spouse-absent,Other-service,Not-in-family,,,0,0,16,,<=50K
52,Self-emp-not-inc,209642,HS-grad,9,Married-civ-spouse,Exec-managerial,Husband,,,0,0,45,,>50K
31,Private,45781,Masters,14,Never-married,Prof-specialty,Not-in-family,,,14084,0,50,,>50K
42,Private,159449,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,,,5178,0,40,,>50K
37,Private,280464,Some-college,10,Married-civ-spouse,Exec-managerial,Husband,,,0,0,80,,>50K
30,State-gov,141297,Bachelors,13,Married-civ-spouse,Prof-specialty,Husband,,,0,0,40,,>50K
23,Private,122272,Bachelors,13,Never-married,Adm-clerical,Own-child,,,0,0,30,,<=50K
32,Private,205019,Assoc-acdm,12,Never-married,Sales,Not-in-family,,,0,0,50,,<=50K
40,Private,121772,Assoc-voc,11,Married-civ-spouse,Craft-repair,Husband,,,0,0,40,,>50K
34,Private,245487,7th-8th,4,Married-civ-spouse,Transport-moving,Husband,,,0,0,45,,<=50K
25,Self-emp-not-inc,176756,HS-grad,9,Never-married,Farming-fishing,Own-child,,,0,0,35,,<=50K
32,Private,186824,HS-grad,9,Never-married,Machine-op-inspct,Unmarried,,,0,0,40,,<=50K
38,Private,28887,11th,7,Married-civ-spouse,Sales,Husband,,,0,0,50,,<=50K
43,Self-emp-not-inc,292175,Masters,14,Divorced,Exec-managerial,Unmarried,,,0,0,45,,>50K
40,Private,193524,Doctorate,16,Married-civ-spouse,Prof-specialty,Husband,,,0,0,60,,>50K
56,Local-gov,216851,Bachelors,13,Married-civ-spouse,Tech-support,Husband,,,0,0,40,,>50K
54,?,180211,Some-college,10,Married-civ-spouse,?,Husband,,,0,0,60,,>50K
22,State-gov,311512,Some-college,10,Married-civ-spouse,Other-service,Husband,,,0,0,15,,<=50K
31,Private,84154,Some-college,10,Married-civ-spouse,Sales,Husband,,,0,0,38,,>50K
57,Federal-gov,337895,Bachelors,13,Married-civ-spouse,Prof-specialty,Husband,,,0,0,40,,>50K
47,Private,51835,Prof-school,15,Married-civ-spouse,Prof-specialty,Wife,,,0,1902,60,,>50K
50,Federal-gov,251585,Bachelors,13,Divorced,Exec-managerial,Not-in-family,,,0,0,55,,>50K
25,Private,289980,HS-grad,9,Never-married,Handlers-cleaners,Not-in-family,,,0,0,35,,<=50K
42,Private,116632,Doctorate,16,Married-civ-spouse,Prof-specialty,Husband,,,0,0,45,,>50K
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
import tensorflow.compat.v1 as tf
from official.r1.wide_deep import census_dataset
from official.r1.wide_deep import census_main
from official.utils.testing import integration
logging.set_verbosity(logging.ERROR)
TEST_INPUT = ('18,Self-emp-not-inc,987,Bachelors,12,Married-civ-spouse,abc,'
'Husband,zyx,wvu,34,56,78,tsr,<=50K')
TEST_INPUT_VALUES = {
'age': 18,
'education_num': 12,
'capital_gain': 34,
'capital_loss': 56,
'hours_per_week': 78,
'education': 'Bachelors',
'marital_status': 'Married-civ-spouse',
'relationship': 'Husband',
'workclass': 'Self-emp-not-inc',
'occupation': 'abc',
}
TEST_CSV = os.path.join(os.path.dirname(__file__), 'census_test.csv')
class BaseTest(tf.test.TestCase):
"""Tests for Wide Deep model."""
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass()
census_main.define_census_flags()
def setUp(self):
# Create temporary CSV file
self.temp_dir = self.get_temp_dir()
self.input_csv = os.path.join(self.temp_dir, 'test.csv')
with tf.io.gfile.GFile(self.input_csv, 'w') as temp_csv:
temp_csv.write(TEST_INPUT)
with tf.io.gfile.GFile(TEST_CSV, 'r') as temp_csv:
test_csv_contents = temp_csv.read()
# Used for end-to-end tests.
for fname in [census_dataset.TRAINING_FILE, census_dataset.EVAL_FILE]:
with tf.io.gfile.GFile(
os.path.join(self.temp_dir, fname), 'w') as test_csv:
test_csv.write(test_csv_contents)
def test_input_fn(self):
dataset = census_dataset.input_fn(self.input_csv, 1, False, 1)
features, labels = dataset.make_one_shot_iterator().get_next()
with self.test_session() as sess:
features, labels = sess.run((features, labels))
# Compare the two features dictionaries.
for key in TEST_INPUT_VALUES:
self.assertTrue(key in features)
self.assertEqual(len(features[key]), 1)
feature_value = features[key][0]
# Convert from bytes to string for Python 3.
if isinstance(feature_value, bytes):
feature_value = feature_value.decode()
self.assertEqual(TEST_INPUT_VALUES[key], feature_value)
self.assertFalse(labels)
def build_and_test_estimator(self, model_type):
"""Ensure that model trains and minimizes loss."""
model = census_main.build_estimator(
self.temp_dir, model_type,
model_column_fn=census_dataset.build_model_columns,
inter_op=0, intra_op=0)
# Train for 1 step to initialize model and evaluate initial loss
def get_input_fn(num_epochs, shuffle, batch_size):
def input_fn():
return census_dataset.input_fn(
TEST_CSV, num_epochs=num_epochs, shuffle=shuffle,
batch_size=batch_size)
return input_fn
model.train(input_fn=get_input_fn(1, True, 1), steps=1)
initial_results = model.evaluate(input_fn=get_input_fn(1, False, 1))
# Train for 100 epochs at batch size 3 and evaluate final loss
model.train(input_fn=get_input_fn(100, True, 3))
final_results = model.evaluate(input_fn=get_input_fn(1, False, 1))
print('%s initial results:' % model_type, initial_results)
print('%s final results:' % model_type, final_results)
# Ensure loss has decreased, while accuracy and both AUCs have increased.
self.assertLess(final_results['loss'], initial_results['loss'])
self.assertGreater(final_results['auc'], initial_results['auc'])
self.assertGreater(final_results['auc_precision_recall'],
initial_results['auc_precision_recall'])
self.assertGreater(final_results['accuracy'], initial_results['accuracy'])
def test_wide_deep_estimator_training(self):
self.build_and_test_estimator('wide_deep')
def test_end_to_end_wide(self):
integration.run_synthetic(
main=census_main.main, tmp_root=self.get_temp_dir(),
extra_flags=[
'--data_dir', self.get_temp_dir(),
'--model_type', 'wide',
'--download_if_missing=false'
],
synth=False)
def test_end_to_end_deep(self):
integration.run_synthetic(
main=census_main.main, tmp_root=self.get_temp_dir(),
extra_flags=[
'--data_dir', self.get_temp_dir(),
'--model_type', 'deep',
'--download_if_missing=false'
],
synth=False)
def test_end_to_end_wide_deep(self):
integration.run_synthetic(
main=census_main.main, tmp_root=self.get_temp_dir(),
extra_flags=[
'--data_dir', self.get_temp_dir(),
'--model_type', 'wide_deep',
'--download_if_missing=false'
],
synth=False)
if __name__ == '__main__':
tf.disable_eager_execution()
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Prepare MovieLens dataset for wide-deep."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
# pylint: disable=wrong-import-order
from absl import app as absl_app
from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
# pylint: enable=wrong-import-order
from official.recommendation import movielens
from official.r1.utils.data import file_io
from official.utils.flags import core as flags_core
_BUFFER_SUBDIR = "wide_deep_buffer"
_FEATURE_MAP = {
movielens.USER_COLUMN: tf.compat.v1.FixedLenFeature([1], dtype=tf.int64),
movielens.ITEM_COLUMN: tf.compat.v1.FixedLenFeature([1], dtype=tf.int64),
movielens.TIMESTAMP_COLUMN: tf.compat.v1.FixedLenFeature([1],
dtype=tf.int64),
movielens.GENRE_COLUMN: tf.compat.v1.FixedLenFeature(
[movielens.N_GENRE], dtype=tf.int64),
movielens.RATING_COLUMN: tf.compat.v1.FixedLenFeature([1],
dtype=tf.float32),
}
_BUFFER_SIZE = {
movielens.ML_1M: {"train": 107978119, "eval": 26994538},
movielens.ML_20M: {"train": 2175203810, "eval": 543802008}
}
_USER_EMBEDDING_DIM = 16
_ITEM_EMBEDDING_DIM = 64
def build_model_columns(dataset):
"""Builds a set of wide and deep feature columns."""
user_id = tf.feature_column.categorical_column_with_vocabulary_list(
movielens.USER_COLUMN, range(1, movielens.NUM_USER_IDS[dataset]))
user_embedding = tf.feature_column.embedding_column(
user_id, _USER_EMBEDDING_DIM, max_norm=np.sqrt(_USER_EMBEDDING_DIM))
item_id = tf.feature_column.categorical_column_with_vocabulary_list(
movielens.ITEM_COLUMN, range(1, movielens.NUM_ITEM_IDS))
item_embedding = tf.feature_column.embedding_column(
item_id, _ITEM_EMBEDDING_DIM, max_norm=np.sqrt(_ITEM_EMBEDDING_DIM))
time = tf.feature_column.numeric_column(movielens.TIMESTAMP_COLUMN)
genres = tf.feature_column.numeric_column(
movielens.GENRE_COLUMN, shape=(movielens.N_GENRE,), dtype=tf.uint8)
deep_columns = [user_embedding, item_embedding, time, genres]
wide_columns = []
return wide_columns, deep_columns
def _deserialize(examples_serialized):
features = tf.parse_example(examples_serialized, _FEATURE_MAP)
return features, features[movielens.RATING_COLUMN] / movielens.MAX_RATING
def _buffer_path(data_dir, dataset, name):
return os.path.join(data_dir, _BUFFER_SUBDIR,
"{}_{}_buffer".format(dataset, name))
def _df_to_input_fn(df, name, dataset, data_dir, batch_size, repeat, shuffle):
"""Serialize a dataframe and write it to a buffer file."""
buffer_path = _buffer_path(data_dir, dataset, name)
expected_size = _BUFFER_SIZE[dataset].get(name)
file_io.write_to_buffer(
dataframe=df, buffer_path=buffer_path,
columns=list(_FEATURE_MAP.keys()), expected_size=expected_size)
def input_fn():
dataset = tf.data.TFRecordDataset(buffer_path)
# batch comes before map because map can deserialize multiple examples.
dataset = dataset.batch(batch_size)
dataset = dataset.map(_deserialize, num_parallel_calls=16)
if shuffle:
dataset = dataset.shuffle(shuffle)
dataset = dataset.repeat(repeat)
return dataset.prefetch(1)
return input_fn
def _check_buffers(data_dir, dataset):
train_path = os.path.join(data_dir, _BUFFER_SUBDIR,
"{}_{}_buffer".format(dataset, "train"))
eval_path = os.path.join(data_dir, _BUFFER_SUBDIR,
"{}_{}_buffer".format(dataset, "eval"))
if not tf.gfile.Exists(train_path) or not tf.gfile.Exists(eval_path):
return False
return all([
tf.gfile.Stat(_buffer_path(data_dir, dataset, "train")).length ==
_BUFFER_SIZE[dataset]["train"],
tf.gfile.Stat(_buffer_path(data_dir, dataset, "eval")).length ==
_BUFFER_SIZE[dataset]["eval"],
])
def construct_input_fns(dataset, data_dir, batch_size=16, repeat=1):
"""Construct train and test input functions, as well as the column fn."""
if _check_buffers(data_dir, dataset):
train_df, eval_df = None, None
else:
df = movielens.csv_to_joint_dataframe(dataset=dataset, data_dir=data_dir)
df = movielens.integerize_genres(dataframe=df)
df = df.drop(columns=[movielens.TITLE_COLUMN])
train_df = df.sample(frac=0.8, random_state=0)
eval_df = df.drop(train_df.index)
train_df = train_df.reset_index(drop=True)
eval_df = eval_df.reset_index(drop=True)
train_input_fn = _df_to_input_fn(
df=train_df, name="train", dataset=dataset, data_dir=data_dir,
batch_size=batch_size, repeat=repeat,
shuffle=movielens.NUM_RATINGS[dataset])
eval_input_fn = _df_to_input_fn(
df=eval_df, name="eval", dataset=dataset, data_dir=data_dir,
batch_size=batch_size, repeat=repeat, shuffle=None)
model_column_fn = functools.partial(build_model_columns, dataset=dataset)
train_input_fn()
return train_input_fn, eval_input_fn, model_column_fn
def main(_):
movielens.download(dataset=flags.FLAGS.dataset, data_dir=flags.FLAGS.data_dir)
construct_input_fns(flags.FLAGS.dataset, flags.FLAGS.data_dir)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
movielens.define_data_download_flags()
flags.adopt_module_key_flags(movielens)
flags_core.set_defaults(dataset="ml-1m")
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train DNN on Kaggle movie dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app as absl_app
from absl import flags
import tensorflow.compat.v1 as tf
from official.r1.utils.logs import logger
from official.r1.wide_deep import movielens_dataset
from official.r1.wide_deep import wide_deep_run_loop
from official.recommendation import movielens
from official.utils.flags import core as flags_core
def define_movie_flags():
"""Define flags for movie dataset training."""
wide_deep_run_loop.define_wide_deep_flags()
flags.DEFINE_enum(
name="dataset", default=movielens.ML_1M,
enum_values=movielens.DATASETS, case_sensitive=False,
help=flags_core.help_wrap("Dataset to be trained and evaluated."))
flags.adopt_module_key_flags(wide_deep_run_loop)
flags_core.set_defaults(data_dir="/tmp/movielens-data/",
model_dir='/tmp/movie_model',
model_type="deep",
train_epochs=50,
epochs_between_evals=5,
inter_op_parallelism_threads=0,
intra_op_parallelism_threads=0,
batch_size=256)
@flags.validator("stop_threshold",
message="stop_threshold not supported for movielens model")
def _no_stop(stop_threshold):
return stop_threshold is None
def build_estimator(model_dir, model_type, model_column_fn, inter_op, intra_op):
"""Build an estimator appropriate for the given model type."""
if model_type != "deep":
raise NotImplementedError("movie dataset only supports `deep` model_type")
_, deep_columns = model_column_fn()
hidden_units = [256, 256, 256, 128]
run_config = tf.estimator.RunConfig().replace(
session_config=tf.ConfigProto(device_count={'GPU': 0},
inter_op_parallelism_threads=inter_op,
intra_op_parallelism_threads=intra_op))
return tf.estimator.DNNRegressor(
model_dir=model_dir,
feature_columns=deep_columns,
hidden_units=hidden_units,
optimizer=tf.compat.v1.train.AdamOptimizer(),
activation_fn=tf.nn.sigmoid,
dropout=0.3,
loss_reduction=tf.losses.Reduction.MEAN)
def run_movie(flags_obj):
"""Construct all necessary functions and call run_loop.
Args:
flags_obj: Object containing user specified flags.
"""
if flags_obj.download_if_missing:
movielens.download(dataset=flags_obj.dataset, data_dir=flags_obj.data_dir)
train_input_fn, eval_input_fn, model_column_fn = \
movielens_dataset.construct_input_fns(
dataset=flags_obj.dataset, data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size, repeat=flags_obj.epochs_between_evals)
tensors_to_log = {
'loss': '{loss_prefix}head/weighted_loss/value'
}
wide_deep_run_loop.run_loop(
name="MovieLens", train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
model_column_fn=model_column_fn,
build_estimator_fn=build_estimator,
flags_obj=flags_obj,
tensors_to_log=tensors_to_log,
early_stop=False)
def main(_):
with logger.benchmark_context(flags.FLAGS):
run_movie(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
define_movie_flags()
absl_app.run(main)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core run logic for TensorFlow Wide & Deep Tutorial using tf.estimator API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
from absl import app as absl_app
from absl import flags
import tensorflow.compat.v1 as tf
from official.r1.utils.logs import hooks_helper
from official.r1.utils.logs import logger
from official.utils.flags import core as flags_core
from official.utils.misc import model_helpers
LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True, stop_threshold=True,
hooks=True, export_dir=True)
flags_core.define_benchmark()
flags_core.define_performance(
num_parallel_calls=False, inter_op=True, intra_op=True,
synthetic_data=False, max_train_steps=False, dtype=False,
all_reduce_alg=False)
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_enum(
name="model_type", short_name="mt", default="wide_deep",
enum_values=['wide', 'deep', 'wide_deep'],
help="Select model topology.")
flags.DEFINE_boolean(
name="download_if_missing", default=True, help=flags_core.help_wrap(
"Download data to data_dir if it is not already present."))
def export_model(model, model_type, export_dir, model_column_fn):
"""Export to SavedModel format.
Args:
model: Estimator object
model_type: string indicating model type. "wide", "deep" or "wide_deep"
export_dir: directory to export the model.
model_column_fn: Function to generate model feature columns.
"""
wide_columns, deep_columns = model_column_fn()
if model_type == 'wide':
columns = wide_columns
elif model_type == 'deep':
columns = deep_columns
else:
columns = wide_columns + deep_columns
feature_spec = tf.feature_column.make_parse_example_spec(columns)
example_input_fn = (
tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
model.export_savedmodel(export_dir, example_input_fn,
strip_default_attrs=True)
def run_loop(name, train_input_fn, eval_input_fn, model_column_fn,
build_estimator_fn, flags_obj, tensors_to_log, early_stop=False):
"""Define training loop."""
model_helpers.apply_clean(flags.FLAGS)
model = build_estimator_fn(
model_dir=flags_obj.model_dir, model_type=flags_obj.model_type,
model_column_fn=model_column_fn,
inter_op=flags_obj.inter_op_parallelism_threads,
intra_op=flags_obj.intra_op_parallelism_threads)
run_params = {
'batch_size': flags_obj.batch_size,
'train_epochs': flags_obj.train_epochs,
'model_type': flags_obj.model_type,
}
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info('wide_deep', name, run_params,
test_id=flags_obj.benchmark_test_id)
loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
tensors_to_log = {k: v.format(loss_prefix=loss_prefix)
for k, v in tensors_to_log.items()}
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks, model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size, tensors_to_log=tensors_to_log)
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
for n in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
model.train(input_fn=train_input_fn, hooks=train_hooks)
results = model.evaluate(input_fn=eval_input_fn)
# Display evaluation metrics
tf.logging.info('Results at epoch %d / %d',
(n + 1) * flags_obj.epochs_between_evals,
flags_obj.train_epochs)
tf.logging.info('-' * 60)
for key in sorted(results):
tf.logging.info('%s: %s' % (key, results[key]))
benchmark_logger.log_evaluation_result(results)
if early_stop and model_helpers.past_stop_threshold(
flags_obj.stop_threshold, results['accuracy']):
break
# Export the model
if flags_obj.export_dir is not None:
export_model(model, flags_obj.model_type, flags_obj.export_dir,
model_column_fn)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment