Commit c7369689 authored by Jaehong Kim's avatar Jaehong Kim Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 381320453
parent 83ad7bda
......@@ -32,12 +32,16 @@ from official.core import exp_factory
from official.modeling import hyperparams
def _get_benchmark_params(benchmark_models):
def _get_benchmark_params(benchmark_models, eval_tflite=False):
"""Formats benchmark params into a list."""
parameterized_benchmark_params = []
for _, benchmarks in benchmark_models.items():
for name, params in benchmarks.items():
for execution_mode in ['performance', 'accuracy']:
if eval_tflite:
execution_modes = ['performance', 'tflite_accuracy']
else:
execution_modes = ['performance', 'accuracy']
for execution_mode in execution_modes:
benchmark_name = '{}.{}'.format(name, execution_mode)
benchmark_params = (
benchmark_name, # First arg is used by ParameterizedBenchmark.
......@@ -66,7 +70,8 @@ class BaseBenchmark( # pylint: disable=undefined-variable
_benchmark_parameters = _get_benchmark_params(
benchmark_definitions.VISION_BENCHMARKS) + _get_benchmark_params(
benchmark_definitions.NLP_BENCHMARKS)
benchmark_definitions.NLP_BENCHMARKS) + _get_benchmark_params(
benchmark_definitions.QAT_BENCHMARKS, True)
def __init__(self,
output_dir=None,
......@@ -144,7 +149,7 @@ class BaseBenchmark( # pylint: disable=undefined-variable
execution_mode, params, self._get_model_dir(benchmark_name))
metrics = []
if execution_mode == 'accuracy':
if execution_mode in ['accuracy', 'tflite_accuracy']:
for metric_bound in metric_bounds:
metric = {
'name': metric_bound['name'],
......
......@@ -51,3 +51,6 @@ VISION_BENCHMARKS = {
NLP_BENCHMARKS = {
}
QAT_BENCHMARKS = {
}
......@@ -21,6 +21,7 @@ from typing import Any, Mapping
from absl import logging
import orbit
import tensorflow as tf
from official.benchmark import tflite_utils
from official.common import distribute_utils
from official.core import config_definitions
from official.core import task_factory
......@@ -37,8 +38,8 @@ def run_benchmark(
"""Runs benchmark for a specific experiment.
Args:
execution_mode: A 'str', specifying the mode. Can be 'accuracy', or
'performance'.
execution_mode: A 'str', specifying the mode. Can be 'accuracy',
'performance', or 'tflite_accuracy'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
distribution_strategy: A tf.distribute.Strategy to use. If specified,
......@@ -46,6 +47,9 @@ def run_benchmark(
Returns:
benchmark_data: returns benchmark data in dict format.
Raises:
NotImplementedError: If try to use unsupported setup.
"""
# For GPU runs, allow option to set thread mode
......@@ -77,7 +81,7 @@ def run_benchmark(
trainer.initialize()
steps_per_loop = params.trainer.steps_per_loop if (
execution_mode == 'accuracy') else 100
execution_mode in ['accuracy', 'tflite_accuracy']) else 100
controller = orbit.Controller(
strategy=strategy,
trainer=trainer,
......@@ -105,6 +109,10 @@ def run_benchmark(
benchmark_data = {'metrics': eval_logs}
elif execution_mode == 'performance':
benchmark_data = {}
elif execution_mode == 'tflite_accuracy':
eval_logs = tflite_utils.train_and_evaluate(
params, task, trainer, controller)
benchmark_data = {'metrics': eval_logs}
else:
raise NotImplementedError(
'The benchmark execution mode is not implemented: %s' %
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFLite utils."""
import orbit
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
def train_and_evaluate(
params: config_definitions.ExperimentConfig,
task: base_task.Task,
trainer: base_trainer.Trainer,
controller: orbit.Controller):
"""Train and evaluate on TFLite."""
raise NotImplementedError('train_and_evaluate on tflite_utils is not '
'implemented yet.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment