Commit d57ba596 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Fix unit test.

PiperOrigin-RevId: 348228865
parent c9c230d9
......@@ -17,6 +17,8 @@ import os
from absl import flags
from absl.testing import parameterized
import dataclasses
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
......@@ -27,17 +29,83 @@ from official.common import registry_imports
# pylint: enable=unused-import
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import params_dict
from official.modeling.progressive import policies
from official.modeling.progressive import train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import progressive_masked_lm
from official.utils.testing import mock_task
FLAGS = flags.FLAGS
tfm_flags.define_flags()
@dataclasses.dataclass
class ProgTaskConfig(cfg.TaskConfig):
pass
@task_factory.register_task_cls(ProgTaskConfig)
class ProgMockTask(policies.ProgressivePolicy, mock_task.MockTask):
"""Progressive task for testing."""
def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
mock_task.MockTask.__init__(
self, params=params, logging_dir=logging_dir)
policies.ProgressivePolicy.__init__(self)
def num_stages(self):
return 2
def num_steps(self, stage_id):
return 2 if stage_id == 0 else 4
def get_model(self, stage_id, old_model=None):
del stage_id, old_model
return self.build_model()
def get_optimizer(self, stage_id):
"""Build optimizer for each stage."""
params = optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 0.01,
'end_learning_rate': 0.0,
'power': 1.0,
'decay_steps': 10,
},
},
'warmup': {
'polynomial': {
'power': 1,
'warmup_steps': 2,
},
'type': 'polynomial',
}
})
opt_factory = optimization.OptimizerFactory(params)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
return optimizer
def get_train_dataset(self, stage_id):
del stage_id
strategy = tf.distribute.get_strategy()
return orbit.utils.make_distributed_dataset(
strategy, self.build_inputs, None)
def get_eval_dataset(self, stage_id):
del stage_id
strategy = tf.distribute.get_strategy()
return orbit.utils.make_distributed_dataset(
strategy, self.build_inputs, None)
class TrainTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
......@@ -76,12 +144,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
model_dir = self.get_temp_dir()
experiment_config = cfg.ExperimentConfig(
trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
task=progressive_masked_lm.ProgMaskedLMConfig(
train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path='dummy'),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
is_training=False,
input_path='dummy')))
task=ProgTaskConfig())
experiment_config = params_dict.override_params_dict(
experiment_config, self._test_config, is_strict=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment