"...resnet50_tensorflow.git" did not exist on "6f4dff46a40e73de03dc4a7fe87363956bb76522"
Commit b9ef963d authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Fix all lint errors in official/vision/image_classification/

PiperOrigin-RevId: 266458583
parent a6f9945a
@@ -38,7 +38,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
   N.B. Only support Keras optimizers, not TF optimizers.

-  Args:
+  Attributes:
     schedule: a function that takes an epoch index and a batch index as input
       (both integer, indexed from 0) and returns a new learning rate as
       output (float).
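The hunk above renames the docstring heading from Args: to Attributes:, the correct section name for attributes documented on a class. For reference, the schedule it describes is just a Python callable over (epoch, batch). A minimal sketch, reusing the LR_SCHEDULE tuples that appear later in this commit; the base rate of 0.1 is an assumption, not something the diff specifies:

LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]

def learning_rate_schedule(current_epoch, current_batch):
  """Hypothetical schedule: both indices are 0-based, returns a float."""
  del current_batch  # this sketch only steps the rate per epoch
  base_lr = 0.1  # assumed base learning rate, not from the diff
  lr = base_lr
  for mult, start_epoch in LR_SCHEDULE:
    if current_epoch >= start_epoch:
      lr = base_lr * mult
  return lr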
@@ -313,6 +313,7 @@ def define_keras_flags(dynamic_loss_scale=True):
       name='enable_get_next_as_optional', default=False,
       help='Enable get_next_as_optional behavior in DistributedIterator.')

+
 def get_synth_input_fn(height, width, num_channels, num_classes,
                        dtype=tf.float32, drop_remainder=True):
   """Returns an input function that returns a dataset with random data.
...
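Only the signature of get_synth_input_fn is visible in the hunk above; the helper exists so benchmarks can run without real data. A rough sketch of such a helper, assuming a body that this commit does not show:

import tensorflow as tf

def get_synth_input_fn(height, width, num_channels, num_classes,
                       dtype=tf.float32, drop_remainder=True):
  """Sketch: build an input_fn that yields random images and labels."""
  def input_fn(batch_size):
    image = tf.random.truncated_normal(
        [height, width, num_channels], dtype=dtype)
    label = tf.random.uniform(
        [1], minval=0, maxval=num_classes, dtype=tf.int32)
    # Repeat a single random example forever; cheap enough for
    # throughput benchmarking where data content does not matter.
    dataset = tf.data.Dataset.from_tensors((image, label)).repeat()
    return dataset.batch(batch_size, drop_remainder=drop_remainder)
  return input_fn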
@@ -16,6 +16,7 @@
 from __future__ import absolute_import
 from __future__ import print_function

+# pylint: disable=g-bad-import-order
 from mock import Mock
 import numpy as np
 import tensorflow as tf
...
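For context, g-bad-import-order is the import-ordering check from Google's pylint configuration; roughly, it expects standard-library, third-party, and application imports in separate, ordered groups, which the from-mock-before-numpy ordering above appears to trip, hence the file-level disable rather than a reorder. The convention, sketched with illustrative module names:

import os  # 1. standard library

import numpy as np  # 2. third party
import tensorflow as tf

from official.utils.flags import core as flags_core  # 3. application code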
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from absl import app as absl_app
 from absl import flags
 import tensorflow as tf
...
@@ -23,6 +23,7 @@ from absl import flags
 from absl import logging
 import tensorflow as tf

+from official.benchmark.models import trivial_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
@@ -31,7 +32,6 @@ from official.utils.misc import model_helpers
 from official.vision.image_classification import common
 from official.vision.image_classification import imagenet_preprocessing
 from official.vision.image_classification import resnet_model
-from official.benchmark.models import trivial_model

 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
@@ -182,7 +182,7 @@ def run(flags_obj):
   with strategy_scope:
     optimizer = common.get_optimizer(lr_schedule)

-    if flags_obj.fp16_implementation == "graph_rewrite":
+    if flags_obj.fp16_implementation == 'graph_rewrite':
       # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
       # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
       # which will ensure tf.compat.v2.keras.mixed_precision and
@@ -190,7 +190,7 @@ def run(flags_obj):
       # up.
       optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
           optimizer)

     # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
     if flags_obj.use_trivial_model:
       model = trivial_model.trivial_model(
...
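For readers unfamiliar with the graph-rewrite fp16 path in the hunks above: tf.train.experimental.enable_mixed_precision_graph_rewrite wraps an existing optimizer, rewrites the graph to compute in float16 where safe, and adds automatic loss scaling. A minimal usage sketch against the TF 2.0-era API; the model here is a placeholder, not from this commit:

import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
# Wrap the optimizer; the rewrite inserts float16 casts into the graph
# and applies dynamic loss scaling during training.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer)

# The wrapped optimizer is then used like any other Keras optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(optimizer=optimizer, loss='mse')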