Commit 31ce1788 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 405501883
parent 26d8e705
......@@ -45,9 +45,7 @@ class EdgetpuBertTrainerTest(tf.test.TestCase):
def setUp(self):
super(EdgetpuBertTrainerTest, self).setUp()
config_path = 'third_party/tensorflow_models/official/projects/edgetpu/nlp/experiments/mobilebert_edgetpu_m.yaml'
self.experiment_params = params.EdgeTPUBERTCustomParams.from_yaml(
config_path)
self.experiment_params = params.EdgeTPUBERTCustomParams()
self.strategy = tf.distribute.get_strategy()
self.experiment_params.train_datasest.input_path = 'dummy'
self.experiment_params.eval_dataset.input_path = 'dummy'
......
......@@ -18,7 +18,7 @@ r"""Export tflite for MobileBERT-EdgeTPU with SQUAD head.
Example usage:
python3 export_tflite_squad.py \
--config_file=third_party/tensorflow_models/official/projects/edgetpu/nlp/experiments/mobilebert_edgetpu_xs.yaml \
--config_file=official/projects/edgetpu/nlp/experiments/mobilebert_edgetpu_xs.yaml \
--export_path=/tmp/ \
--quantization_method=full-integer
"""
......
......@@ -16,7 +16,6 @@
from absl import flags
import tensorflow as tf
import yaml
from official.projects.edgetpu.nlp.configs import params
from official.projects.edgetpu.nlp.modeling import model_builder
......@@ -48,25 +47,21 @@ class UtilsTest(tf.test.TestCase):
def test_config_override(self):
# Define several dummy flags which are call by the utils.config_override
# function.
file_path = 'third_party/tensorflow_models/official/projects/edgetpu/nlp/experiments/mobilebert_edgetpu_m.yaml'
flags.DEFINE_string('tpu', None, 'tpu_address.')
flags.DEFINE_list('config_file', [file_path],
flags.DEFINE_list('config_file', [],
'A list of config files path.')
flags.DEFINE_string('params_override', None, 'Override params.')
flags.DEFINE_string('params_override',
'orbit_config.mode=eval', 'Override params.')
flags.DEFINE_string('model_dir', '/tmp/', 'Model saving directory.')
flags.DEFINE_list('mode', ['train'], 'Job mode.')
flags.DEFINE_bool('use_vizier', False,
'Whether to enable vizier based hyperparameter search.')
experiment_params = params.EdgeTPUBERTCustomParams()
# By default, the orbit is set with train mode.
self.assertEqual(experiment_params.orbit_config.mode, 'train')
# Config override should set the orbit to eval mode.
experiment_params = utils.config_override(experiment_params, FLAGS)
experiment_params_dict = experiment_params.as_dict()
with tf.io.gfile.GFile(file_path, 'r') as f:
loaded_dict = yaml.load(f, Loader=yaml.FullLoader)
# experiment_params contains all the configs but the loaded_dict might
# only contains partial of the configs.
self.assertTrue(nested_dict_compare(loaded_dict, experiment_params_dict))
self.assertEqual(experiment_params.orbit_config.mode, 'eval')
def test_load_checkpoint(self):
"""Test the pretrained model can be successfully loaded."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment