Commit 31ce1788 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 405501883
parent 26d8e705
...@@ -45,9 +45,7 @@ class EdgetpuBertTrainerTest(tf.test.TestCase): ...@@ -45,9 +45,7 @@ class EdgetpuBertTrainerTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(EdgetpuBertTrainerTest, self).setUp() super(EdgetpuBertTrainerTest, self).setUp()
config_path = 'third_party/tensorflow_models/official/projects/edgetpu/nlp/experiments/mobilebert_edgetpu_m.yaml' self.experiment_params = params.EdgeTPUBERTCustomParams()
self.experiment_params = params.EdgeTPUBERTCustomParams.from_yaml(
config_path)
self.strategy = tf.distribute.get_strategy() self.strategy = tf.distribute.get_strategy()
self.experiment_params.train_datasest.input_path = 'dummy' self.experiment_params.train_datasest.input_path = 'dummy'
self.experiment_params.eval_dataset.input_path = 'dummy' self.experiment_params.eval_dataset.input_path = 'dummy'
......
...@@ -18,7 +18,7 @@ r"""Export tflite for MobileBERT-EdgeTPU with SQUAD head. ...@@ -18,7 +18,7 @@ r"""Export tflite for MobileBERT-EdgeTPU with SQUAD head.
Example usage: Example usage:
python3 export_tflite_squad.py \ python3 export_tflite_squad.py \
--config_file=third_party/tensorflow_models/official/projects/edgetpu/nlp/experiments/mobilebert_edgetpu_xs.yaml \ --config_file=official/projects/edgetpu/nlp/experiments/mobilebert_edgetpu_xs.yaml \
--export_path=/tmp/ \ --export_path=/tmp/ \
--quantization_method=full-integer --quantization_method=full-integer
""" """
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
import yaml
from official.projects.edgetpu.nlp.configs import params from official.projects.edgetpu.nlp.configs import params
from official.projects.edgetpu.nlp.modeling import model_builder from official.projects.edgetpu.nlp.modeling import model_builder
...@@ -48,25 +47,21 @@ class UtilsTest(tf.test.TestCase): ...@@ -48,25 +47,21 @@ class UtilsTest(tf.test.TestCase):
def test_config_override(self): def test_config_override(self):
# Define several dummy flags which are call by the utils.config_override # Define several dummy flags which are call by the utils.config_override
# function. # function.
file_path = 'third_party/tensorflow_models/official/projects/edgetpu/nlp/experiments/mobilebert_edgetpu_m.yaml'
flags.DEFINE_string('tpu', None, 'tpu_address.') flags.DEFINE_string('tpu', None, 'tpu_address.')
flags.DEFINE_list('config_file', [file_path], flags.DEFINE_list('config_file', [],
'A list of config files path.') 'A list of config files path.')
flags.DEFINE_string('params_override', None, 'Override params.') flags.DEFINE_string('params_override',
'orbit_config.mode=eval', 'Override params.')
flags.DEFINE_string('model_dir', '/tmp/', 'Model saving directory.') flags.DEFINE_string('model_dir', '/tmp/', 'Model saving directory.')
flags.DEFINE_list('mode', ['train'], 'Job mode.') flags.DEFINE_list('mode', ['train'], 'Job mode.')
flags.DEFINE_bool('use_vizier', False, flags.DEFINE_bool('use_vizier', False,
'Whether to enable vizier based hyperparameter search.') 'Whether to enable vizier based hyperparameter search.')
experiment_params = params.EdgeTPUBERTCustomParams() experiment_params = params.EdgeTPUBERTCustomParams()
# By default, the orbit is set with train mode.
self.assertEqual(experiment_params.orbit_config.mode, 'train')
# Config override should set the orbit to eval mode.
experiment_params = utils.config_override(experiment_params, FLAGS) experiment_params = utils.config_override(experiment_params, FLAGS)
experiment_params_dict = experiment_params.as_dict() self.assertEqual(experiment_params.orbit_config.mode, 'eval')
with tf.io.gfile.GFile(file_path, 'r') as f:
loaded_dict = yaml.load(f, Loader=yaml.FullLoader)
# experiment_params contains all the configs but the loaded_dict might
# only contains partial of the configs.
self.assertTrue(nested_dict_compare(loaded_dict, experiment_params_dict))
def test_load_checkpoint(self): def test_load_checkpoint(self):
"""Test the pretrained model can be successfully loaded.""" """Test the pretrained model can be successfully loaded."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment