Commit 01748b24 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal changes.

PiperOrigin-RevId: 301237793
parent ad710aa1
...@@ -26,14 +26,13 @@ from absl import logging ...@@ -26,14 +26,13 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
import tensorflow_model_optimization as tfmot import tensorflow_model_optimization as tfmot
from official.benchmark.models import trivial_model
from official.modeling import performance from official.modeling import performance
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils
from official.vision.image_classification.resnet import common from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model from official.vision.image_classification.resnet import resnet_model
...@@ -180,8 +179,7 @@ def run(flags_obj): ...@@ -180,8 +179,7 @@ def run(flags_obj):
# TODO(hongkuny): Remove trivial model usage and move it to benchmark. # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
if flags_obj.use_trivial_model: if flags_obj.use_trivial_model:
model = trivial_model.trivial_model( model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
imagenet_preprocessing.NUM_CLASSES)
elif flags_obj.model == 'resnet50_v1.5': elif flags_obj.model == 'resnet50_v1.5':
model = resnet_model.resnet50( model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES) num_classes=imagenet_preprocessing.NUM_CLASSES)
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A trivial model for Keras.""" """Test utilities for image classification tasks."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment