# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf

# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.nlp import continuous_finetune_lib

# Define the shared Model Garden flags once, at import time. The test below
# overrides several of them (experiment, mode, model_dir, params_override)
# through flagsaver and then reads them back via FLAGS.
tfm_flags.define_flags()

FLAGS = flags.FLAGS


class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end test for `continuous_finetune_lib.run_continuous_finetune`."""

  def setUp(self):
    super().setUp()
    # Output directory for the continuous-finetune run (passed as --model_dir).
    self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')

  def testContinuousFinetune(self):
    """Pretrains a mock experiment, then continuously finetunes from it.

    A short `mode='train'` run writes checkpoints into `src_model_dir`. That
    same directory is configured as `task.init_checkpoint`, so
    `run_continuous_finetune` picks those checkpoints up. The test then checks
    that the returned metrics contain 'best_acc' and that no 'checkpoint'
    file exists at the root of the finetune model_dir.
    """
    # Passed to run_continuous_finetune below; equals trainer.train_steps (1),
    # which the pretraining run also uses.
    pretrain_steps = 1
    # Pretraining checkpoints go to the test's temp dir. get_temp_dir() returns
    # the same directory on every call within a test, so self._model_dir
    # (set in setUp) is a subdirectory of src_model_dir.
    src_model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode='continuous_train_and_eval',
        model_dir=self._model_dir,
        params_override={
            'task': {
                # Finetuning initializes from the checkpoints trained below.
                'init_checkpoint': src_model_dir,
            },
            'trainer': {
                'continuous_eval_timeout': 1,
                'steps_per_loop': 1,
                'train_steps': 1,
                'validation_steps': 1,
                'best_checkpoint_export_subdir': 'best_ckpt',
                'best_checkpoint_eval_metric': 'acc',
                'optimizer_config': {
                    'optimizer': {
                        'type': 'sgd'
                    },
                    'learning_rate': {
                        'type': 'constant'
                    }
                }
            }
        })

    with flagsaver.flagsaver(**flags_dict):
      # Train and save some checkpoints into src_model_dir.
      params = train_utils.parse_configuration(FLAGS)
      distribution_strategy = tf.distribute.get_strategy()
      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=src_model_dir)
      _ = train_lib.run_experiment(
          distribution_strategy=distribution_strategy,
          task=task,
          mode='train',
          params=params,
          model_dir=src_model_dir)

      # Re-parse the flags so the finetuning run gets its own config object
      # rather than the one already handed to run_experiment above.
      params = train_utils.parse_configuration(FLAGS)
      eval_metrics = continuous_finetune_lib.run_continuous_finetune(
          FLAGS.mode,
          params,
          FLAGS.model_dir,
          run_post_eval=True,
          pretrain_steps=pretrain_steps)
      # NOTE(review): the key presumably comes from 'best_' +
      # trainer.best_checkpoint_eval_metric ('acc'); confirm in
      # continuous_finetune_lib.
      self.assertIn('best_acc', eval_metrics)

      # No checkpoint index file may be written directly under model_dir.
      # NOTE(review): presumably finetune checkpoints are written to
      # subdirectories of model_dir instead -- confirm.
      self.assertFalse(
          tf.io.gfile.exists(os.path.join(FLAGS.model_dir, 'checkpoint')))


if __name__ == '__main__':
  # When executed as a script, run this module's test cases with TensorFlow's
  # test runner.
  tf.test.main()