".github/vscode:/vscode.git/clone" did not exist on "fd5ce576a428270e2f7ef270e9cbb4ea657ff026"
Unverified Commit 7cfb6bbd authored by Karmel Allison's avatar Karmel Allison Committed by GitHub
Browse files

Glint everything (#3654)

* Glint everything

* Adding rcfile and pylinting

* Extra newline

* Few last lints
parent adfd5a3a
...@@ -70,6 +70,5 @@ class BenchmarkLogger(object): ...@@ -70,6 +70,5 @@ class BenchmarkLogger(object):
json.dump(metric, f) json.dump(metric, f)
f.write("\n") f.write("\n")
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
tf.logging.warning("Failed to dump metric to log file: name %s, value %s, error %s", tf.logging.warning("Failed to dump metric to log file: "
name, value, e) "name %s, value %s, error %s", name, value, e)
...@@ -19,14 +19,13 @@ from __future__ import absolute_import ...@@ -19,14 +19,13 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import json import json
import os import os
import tempfile import tempfile
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logging import logger from official.utils.logging import logger
import tensorflow as tf
class BenchmarkLoggerTest(tf.test.TestCase): class BenchmarkLoggerTest(tf.test.TestCase):
......
...@@ -33,10 +33,10 @@ def run_synthetic(main, tmp_root, extra_flags=None): ...@@ -33,10 +33,10 @@ def run_synthetic(main, tmp_root, extra_flags=None):
very limited run is performed using synthetic data. very limited run is performed using synthetic data.
Args: Args:
main: The primary function used to excercise a code path. Generally this main: The primary function used to exercise a code path. Generally this
function is "<MODULE>.main(argv)". function is "<MODULE>.main(argv)".
tmp_root: Root path for the temp directory created by the test class. tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the the caller of this function. extra_flags: Additional flags passed by the caller of this function.
""" """
extra_flags = [] if extra_flags is None else extra_flags extra_flags = [] if extra_flags is None else extra_flags
......
[MESSAGES CONTROL]
disable=R,W,
bad-option-value
[REPORTS]
# Tells whether to display a full report or only the messages
reports=no
# Activate the evaluation score.
score=no
[BASIC]
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
# Regular expression matching correct function names
function-rgx=^(?:(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct method names
method-rgx=^(?:(?P<exempt>__[a-z0-9_]+__|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*)|__init__|PRESUBMIT|PRESUBMIT_unittest$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|.*ArgParser)
# Naming hint for variable names
variable-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
[TYPECHECK]
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=official, official.*, tensorflow, tensorflow.*, LazyLoader
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,__new__,setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make
# This is deprecated, because it is not used anymore.
#ignore-iface-methods=
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
# Argument names that match this expression will be ignored. Default to name
# with leading underscore
ignored-argument-names=_.*
# Maximum number of arguments for function / method
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of branch for function / method body
max-branches=12
# Maximum number of locals for function / method body
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body
max-returns=6
# Maximum number of statements in function / method body
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,Exception,BaseException
[FORMAT]
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=80
# Maximum number of lines in a module
max-module-lines=99999
# List of optional constructs for which whitespace checking is disabled
no-space-check=
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
# Tells whether we should check for unused import in __init__ files.
init-import=no
...@@ -55,7 +55,7 @@ def _download_and_clean_file(filename, url): ...@@ -55,7 +55,7 @@ def _download_and_clean_file(filename, url):
tf.gfile.Remove(temp_file) tf.gfile.Remove(temp_file)
def main(unused_argv): def main(_):
if not tf.gfile.Exists(FLAGS.data_dir): if not tf.gfile.Exists(FLAGS.data_dir):
tf.gfile.MkDir(FLAGS.data_dir) tf.gfile.MkDir(FLAGS.data_dir)
......
...@@ -22,9 +22,9 @@ import os ...@@ -22,9 +22,9 @@ import os
import shutil import shutil
import sys import sys
import tensorflow as tf import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.arg_parsers import parsers # pylint: disable=g-bad-import-order from official.utils.arg_parsers import parsers
from official.utils.logging import hooks_helper from official.utils.logging import hooks_helper
_CSV_COLUMNS = [ _CSV_COLUMNS = [
...@@ -171,7 +171,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size): ...@@ -171,7 +171,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
return dataset return dataset
def main(unused_argv): def main(_):
# Clean up the model directory if present # Clean up the model directory if present
shutil.rmtree(FLAGS.model_dir, ignore_errors=True) shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
model = build_estimator(FLAGS.model_dir, FLAGS.model_type) model = build_estimator(FLAGS.model_dir, FLAGS.model_type)
...@@ -179,6 +179,13 @@ def main(unused_argv): ...@@ -179,6 +179,13 @@ def main(unused_argv):
train_file = os.path.join(FLAGS.data_dir, 'adult.data') train_file = os.path.join(FLAGS.data_dir, 'adult.data')
test_file = os.path.join(FLAGS.data_dir, 'adult.test') test_file = os.path.join(FLAGS.data_dir, 'adult.test')
# Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
def train_input_fn():
return input_fn(train_file, FLAGS.epochs_per_eval, True, FLAGS.batch_size)
def eval_input_fn():
return input_fn(test_file, 1, False, FLAGS.batch_size)
train_hooks = hooks_helper.get_train_hooks( train_hooks = hooks_helper.get_train_hooks(
FLAGS.hooks, batch_size=FLAGS.batch_size, FLAGS.hooks, batch_size=FLAGS.batch_size,
tensors_to_log={'average_loss': 'head/truediv', tensors_to_log={'average_loss': 'head/truediv',
...@@ -186,13 +193,8 @@ def main(unused_argv): ...@@ -186,13 +193,8 @@ def main(unused_argv):
# Train and evaluate the model every `FLAGS.epochs_between_evals` epochs. # Train and evaluate the model every `FLAGS.epochs_between_evals` epochs.
for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals): for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
model.train( model.train(input_fn=train_input_fn, hooks=train_hooks)
input_fn=lambda: input_fn(train_file, FLAGS.epochs_between_evals, True, results = model.evaluate(input_fn=eval_input_fn)
FLAGS.batch_size),
hooks=train_hooks)
results = model.evaluate(input_fn=lambda: input_fn(
test_file, 1, False, FLAGS.batch_size))
# Display evaluation metrics # Display evaluation metrics
print('Results at epoch', (n + 1) * FLAGS.epochs_between_evals) print('Results at epoch', (n + 1) * FLAGS.epochs_between_evals)
...@@ -204,6 +206,7 @@ def main(unused_argv): ...@@ -204,6 +206,7 @@ def main(unused_argv):
class WideDeepArgParser(argparse.ArgumentParser): class WideDeepArgParser(argparse.ArgumentParser):
"""Argument parser for running the wide deep model.""" """Argument parser for running the wide deep model."""
def __init__(self): def __init__(self):
super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()]) super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()])
self.add_argument( self.add_argument(
......
...@@ -19,7 +19,7 @@ from __future__ import print_function ...@@ -19,7 +19,7 @@ from __future__ import print_function
import os import os
import tensorflow as tf import tensorflow as tf # pylint: disable=g-bad-import-order
from official.wide_deep import wide_deep from official.wide_deep import wide_deep
...@@ -45,6 +45,7 @@ TEST_CSV = os.path.join(os.path.dirname(__file__), 'wide_deep_test.csv') ...@@ -45,6 +45,7 @@ TEST_CSV = os.path.join(os.path.dirname(__file__), 'wide_deep_test.csv')
class BaseTest(tf.test.TestCase): class BaseTest(tf.test.TestCase):
"""Tests for Wide Deep model."""
def setUp(self): def setUp(self):
# Create temporary CSV file # Create temporary CSV file
...@@ -79,21 +80,19 @@ class BaseTest(tf.test.TestCase): ...@@ -79,21 +80,19 @@ class BaseTest(tf.test.TestCase):
model = wide_deep.build_estimator(self.temp_dir, model_type) model = wide_deep.build_estimator(self.temp_dir, model_type)
# Train for 1 step to initialize model and evaluate initial loss # Train for 1 step to initialize model and evaluate initial loss
model.train( def get_input_fn(num_epochs, shuffle, batch_size):
input_fn=lambda: wide_deep.input_fn( def input_fn():
TEST_CSV, num_epochs=1, shuffle=True, batch_size=1), return wide_deep.input_fn(
steps=1) TEST_CSV, num_epochs=num_epochs, shuffle=shuffle,
initial_results = model.evaluate( batch_size=batch_size)
input_fn=lambda: wide_deep.input_fn( return input_fn
TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
model.train(input_fn=get_input_fn(1, True, 1), steps=1)
initial_results = model.evaluate(input_fn=get_input_fn(1, False, 1))
# Train for 100 epochs at batch size 3 and evaluate final loss # Train for 100 epochs at batch size 3 and evaluate final loss
model.train( model.train(input_fn=get_input_fn(100, True, 3))
input_fn=lambda: wide_deep.input_fn( final_results = model.evaluate(input_fn=get_input_fn(1, False, 1))
TEST_CSV, num_epochs=100, shuffle=True, batch_size=3))
final_results = model.evaluate(
input_fn=lambda: wide_deep.input_fn(
TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
print('%s initial results:' % model_type, initial_results) print('%s initial results:' % model_type, initial_results)
print('%s final results:' % model_type, final_results) print('%s final results:' % model_type, final_results)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment