Commit c9df0dd4 authored by Le Hou, committed by A. Unique TensorFlower

Fix unit test.

PiperOrigin-RevId: 348228865
parent 275cb4e8
@@ -17,6 +17,8 @@ import os
 from absl import flags
 from absl.testing import parameterized
+import dataclasses
+import orbit
 import tensorflow as tf
 from tensorflow.python.distribute import combinations
@@ -27,17 +29,83 @@ from official.common import registry_imports
 # pylint: enable=unused-import
 from official.core import config_definitions as cfg
 from official.core import task_factory
+from official.modeling import optimization
 from official.modeling.hyperparams import params_dict
+from official.modeling.progressive import policies
 from official.modeling.progressive import train_lib
 from official.modeling.progressive import trainer as prog_trainer_lib
-from official.nlp.data import pretrain_dataloader
-from official.nlp.tasks import progressive_masked_lm
+from official.utils.testing import mock_task
 
 FLAGS = flags.FLAGS
 tfm_flags.define_flags()
+
+
+@dataclasses.dataclass
+class ProgTaskConfig(cfg.TaskConfig):
+  pass
+
+
+@task_factory.register_task_cls(ProgTaskConfig)
+class ProgMockTask(policies.ProgressivePolicy, mock_task.MockTask):
+  """Progressive task for testing."""
+
+  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
+    mock_task.MockTask.__init__(
+        self, params=params, logging_dir=logging_dir)
+    policies.ProgressivePolicy.__init__(self)
+
+  def num_stages(self):
+    return 2
+
+  def num_steps(self, stage_id):
+    return 2 if stage_id == 0 else 4
+
+  def get_model(self, stage_id, old_model=None):
+    del stage_id, old_model
+    return self.build_model()
+
+  def get_optimizer(self, stage_id):
+    """Build optimizer for each stage."""
+    params = optimization.OptimizationConfig({
+        'optimizer': {
+            'type': 'adamw',
+        },
+        'learning_rate': {
+            'type': 'polynomial',
+            'polynomial': {
+                'initial_learning_rate': 0.01,
+                'end_learning_rate': 0.0,
+                'power': 1.0,
+                'decay_steps': 10,
+            },
+        },
+        'warmup': {
+            'polynomial': {
+                'power': 1,
+                'warmup_steps': 2,
+            },
+            'type': 'polynomial',
+        }
+    })
+    opt_factory = optimization.OptimizerFactory(params)
+    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
+    return optimizer
+
+  def get_train_dataset(self, stage_id):
+    del stage_id
+    strategy = tf.distribute.get_strategy()
+    return orbit.utils.make_distributed_dataset(
+        strategy, self.build_inputs, None)
+
+  def get_eval_dataset(self, stage_id):
+    del stage_id
+    strategy = tf.distribute.get_strategy()
+    return orbit.utils.make_distributed_dataset(
+        strategy, self.build_inputs, None)
+
+
 class TrainTest(tf.test.TestCase, parameterized.TestCase):
 
   def setUp(self):
@@ -76,12 +144,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
     model_dir = self.get_temp_dir()
     experiment_config = cfg.ExperimentConfig(
         trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
-        task=progressive_masked_lm.ProgMaskedLMConfig(
-            train_data=pretrain_dataloader.BertPretrainDataConfig(
-                input_path='dummy'),
-            validation_data=pretrain_dataloader.BertPretrainDataConfig(
-                is_training=False,
-                input_path='dummy')))
+        task=ProgTaskConfig())
     experiment_config = params_dict.override_params_dict(
         experiment_config, self._test_config, is_strict=False)
...
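
For context, policies.ProgressivePolicy is the hook the progressive trainer drives: it asks the task for the number of stages, the step budget per stage, and a freshly built model, optimizer, and dataset at each stage boundary. The sketch below illustrates that consumption pattern using the mock task defined in this patch; it is a simplified approximation for readers, not the actual loop in prog_trainer_lib, and assumes eager execution under the default (single-device) distribution strategy.

  # Illustrative sketch only; ProgMockTask and ProgTaskConfig come from the
  # test file patched above. The real trainer also handles checkpointing and
  # stage transitions, which are omitted here.
  task = ProgMockTask(params=ProgTaskConfig())
  for stage_id in range(task.num_stages()):     # two stages in this mock
    model = task.get_model(stage_id)            # rebuilt at each stage boundary
    optimizer = task.get_optimizer(stage_id)    # AdamW, polynomial decay/warmup
    train_ds = task.get_train_dataset(stage_id)
    num_steps = task.num_steps(stage_id)        # 2 steps in stage 0, 4 in stage 1
    # ...run num_steps training steps over train_ds with model/optimizer...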