"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "86f9b582d5d8ccd44ba6bb8daddea5774209ef7c"
Commit 7f596d87 authored by Le Hou, committed by A. Unique TensorFlower

Open source the progressive training library.

PiperOrigin-RevId: 348113609
parent d9a3b7f0
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base ProgressivePolicy definition for progressive training.
To write a progressive model, subclass ProgressivePolicy and implement its
abstract methods to handle each training stage.
"""
import abc
from typing import Any, Mapping
from absl import logging
import dataclasses
import six
import tensorflow as tf
from official.modeling.hyperparams import base_config
from official.modeling.progressive import utils
@dataclasses.dataclass
class ProgressiveConfig(base_config.Config):
pass
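# A task-specific subclass would add its own fields. An illustrative sketch
# (the class and field names below are hypothetical):
#
#   @dataclasses.dataclass
#   class MyProgressiveConfig(ProgressiveConfig):
#     num_stages: int = 2
#     steps_per_stage: int = 1000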
@six.add_metaclass(abc.ABCMeta)
class ProgressivePolicy:
"""The APIs for handling progressive training stages.
Attributes:
cur_model: The model for the current progressive training stage.
cur_train_dataset: The train dataset function for the current stage.
cur_eval_dataset: The eval dataset function for the current stage.
cur_optimizer: The optimizer for the current stage.
cur_checkpoint_items: Items to be saved in and restored from checkpoints,
for the progressive trainer.
is_last_stage: Whether it is currently in the last stage.
Interfaces:
    is_stage_advancing: Returns whether progressive training is advancing to
      the next stage.
    update_pt_stage: Updates the progressive training stage.
"""
def __init__(self):
"""Initialize stage policy."""
self._cur_train_dataset = None
self._cur_eval_dataset = None
self._volatiles = utils.VolatileTrackable(optimizer=None, model=None)
stage_id = 0
self._stage_id = tf.Variable(
stage_id,
trainable=False,
dtype=tf.int64,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
shape=[])
self._volatiles.reassign_trackable(
optimizer=self.get_optimizer(stage_id),
model=self.get_model(stage_id, old_model=None))
def compute_stage_id(self, global_step: int) -> int:
for stage_id in range(self.num_stages()):
global_step -= self.num_steps(stage_id)
if global_step < 0:
return stage_id
logging.error('Global step %d found no matching progressive stages. '
'Default to the last stage.', global_step)
return self.num_stages() - 1
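  # Worked example (illustrative values only): if num_steps(0) == 100 and
  # num_steps(1) == 200, then global_step 150 maps to stage 1, since
  # 150 - 100 = 50 < 200. A global_step of 350 is past every stage, so the
  # method logs an error and falls back to the last stage.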
@abc.abstractmethod
def num_stages(self) -> int:
"""Return the total number of progressive stages."""
pass
@abc.abstractmethod
def num_steps(self, stage_id: int) -> int:
"""Return the total number of steps in this stage."""
pass
@abc.abstractmethod
def get_model(self,
stage_id: int,
old_model: tf.keras.Model = None) -> tf.keras.Model:
"""Return model for this stage. For initialization, `old_model` = None."""
pass
@abc.abstractmethod
def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
"""Return optimizer for this stage."""
pass
@abc.abstractmethod
def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
"""Return training Dataset for this stage."""
pass
@abc.abstractmethod
def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
"""Return evaluation Dataset for this stage."""
pass
@property
def cur_model(self) -> tf.keras.Model:
return self._volatiles.model
@property
def cur_train_dataset(self) -> tf.data.Dataset:
if self._cur_train_dataset is None:
self._cur_train_dataset = self.get_train_dataset(self._stage_id.numpy())
return self._cur_train_dataset
@property
def cur_eval_dataset(self) -> tf.data.Dataset:
if self._cur_eval_dataset is None:
self._cur_eval_dataset = self.get_eval_dataset(self._stage_id.numpy())
return self._cur_eval_dataset
@property
def cur_optimizer(self) -> tf.keras.optimizers.Optimizer:
return self._volatiles.optimizer
@property
def is_last_stage(self) -> bool:
stage_id = self._stage_id.numpy()
return stage_id >= self.num_stages() - 1
@property
def cur_checkpoint_items(self) -> Mapping[str, Any]:
return dict(stage_id=self._stage_id, volatiles=self._volatiles)
def is_stage_advancing(self, global_step: int) -> bool:
old_stage_id = self._stage_id.numpy()
new_stage_id = self.compute_stage_id(global_step)
return old_stage_id != new_stage_id
def update_pt_stage(self, global_step: int, pass_old_model=True) -> None:
"""Update progressive training internal status.
Call this after a training loop ends.
Args:
global_step: an integer scalar of the current global step.
pass_old_model: whether to pass the old_model to get_model() function.
        This is set to False if the old_model is irrelevant (e.g., just a default
model from stage 0).
"""
old_stage_id = self._stage_id.numpy()
new_stage_id = self.compute_stage_id(global_step)
logging.info('Switching stage from %d to %d', old_stage_id, new_stage_id)
# Update stage id.
self._stage_id.assign(new_stage_id)
# Update dataset function.
self._cur_train_dataset = None
self._cur_eval_dataset = None
# Update optimizer and model.
new_optimizer = self.get_optimizer(new_stage_id)
self._volatiles.reassign_trackable(optimizer=new_optimizer)
new_model = self.get_model(
new_stage_id, old_model=self.cur_model if pass_old_model else None)
self._volatiles.reassign_trackable(model=new_model)
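# Below is a minimal sketch of a concrete policy, for illustration only (it is
# not part of the library, and the class is hypothetical). It grows a tiny
# Keras model from one to two Dense layers across two stages. A real policy is
# usually also a base_task.Task, as in the progressive trainer.
class _ExampleTwoStagePolicy(ProgressivePolicy):
  """Two stages of 1000 steps each; stage 1 adds a second Dense layer."""

  def num_stages(self) -> int:
    return 2

  def num_steps(self, stage_id: int) -> int:
    return 1000

  def get_model(self,
                stage_id: int,
                old_model: tf.keras.Model = None) -> tf.keras.Model:
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(8) for _ in range(stage_id + 1)])
    model.build(input_shape=(None, 8))  # Build so that weights can be copied.
    if old_model is not None:
      # Warm-start the shared (first) layer from the previous stage.
      model.layers[0].set_weights(old_model.layers[0].get_weights())
    return model

  def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
    # E.g., decay the learning rate at the later stage.
    return tf.keras.optimizers.SGD(0.1 if stage_id == 0 else 0.01)

  def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
    return tf.data.Dataset.from_tensors(
        (tf.zeros([2, 8]), tf.zeros([2, 1]))).repeat()

  def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
    return self.get_train_dataset(stage_id)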
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM binary for the progressive trainer."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_utils
from official.modeling import performance
from official.modeling.progressive import train_lib
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise, a continuous eval
    # job may race against the train job when writing the same file.
train_utils.serialize_config(params, model_dir)
  # Sets the mixed-precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly speed up models by using float16 on GPUs and bfloat16 on
  # TPUs. loss_scale takes effect only when the dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu,
**params.runtime.model_parallelism())
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
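# Example invocation (illustrative; the experiment name and paths below are
# hypothetical, and the flags are the common TFM flags defined by
# tfm_flags.define_flags()):
#
#   python3 -m official.modeling.progressive.train \
#     --experiment=my_progressive_experiment \
#     --mode=train_and_eval \
#     --model_dir=/tmp/prog_model \
#     --config_file=/path/to/experiment.yaml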
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM progressive training driver library.
Compared to the common training driver, the only difference is that we use
prog_trainer_lib.ProgressiveTrainer instead of the base trainer.
"""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Tuple
# Import libraries
from absl import logging
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.core import train_lib as base_train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
def run_experiment(distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True) \
-> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
Args:
    distribution_strategy: A distribution strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run a post-training evaluation once after
      training; its metrics logs are returned.
save_summary: Whether to save train and validation summary.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
with distribution_strategy.scope():
logging.info('Running progressive trainer.')
trainer = prog_trainer_lib.ProgressiveTrainer(
params, task, ckpt_dir=model_dir,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=base_train_lib.maybe_create_best_ckpt_exporter(
params, model_dir))
if trainer.checkpoint:
checkpoint_manager = tf.train.CheckpointManager(
trainer.checkpoint,
directory=model_dir,
max_to_keep=params.trainer.max_to_keep,
step_counter=trainer.global_step,
checkpoint_interval=params.trainer.checkpoint_interval,
init_fn=trainer.initialize)
else:
checkpoint_manager = None
controller = orbit.Controller(
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
evaluator=trainer,
global_step=trainer.global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
eval_summary_dir=os.path.join(model_dir, 'validation') if
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
(save_summary) else None)
  logging.info('Starting to execute mode: %s', mode)
with distribution_strategy.scope():
if mode == 'train':
controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':
      def timeout_fn():
        return trainer.global_step.numpy() >= params.trainer.train_steps
controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
if run_post_eval:
with distribution_strategy.scope():
return trainer.model, trainer.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps))
else:
return trainer.model, {}
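# Example usage (an illustrative sketch; it assumes `task` and `params` were
# built elsewhere, e.g. via task_factory.get_task and
# train_utils.parse_configuration, and the model_dir path is hypothetical):
#
#   strategy = tf.distribute.get_strategy()
#   model, eval_logs = run_experiment(
#       distribution_strategy=strategy,
#       task=task,
#       mode='train_and_eval',
#       params=params,
#       model_dir='/tmp/prog_model',
#       run_post_eval=True)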
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the progressive train_lib."""
import os
from absl import flags
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import flags as tfm_flags
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.hyperparams import params_dict
from official.modeling.progressive import train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import progressive_masked_lm
FLAGS = flags.FLAGS
tfm_flags.define_flags()
class TrainTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainTest, self).setUp()
self._test_config = {
'trainer': {
'checkpoint_interval': 10,
'steps_per_loop': 10,
'summary_interval': 10,
'train_steps': 10,
'validation_steps': 5,
'validation_interval': 10,
'continuous_eval_timeout': 1,
'optimizer_config': {
'optimizer': {
'type': 'sgd',
},
'learning_rate': {
'type': 'constant'
}
}
},
}
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode='eager',
flag_mode=['train', 'eval', 'train_and_eval'],
run_post_eval=[True, False]))
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
model_dir = self.get_temp_dir()
experiment_config = cfg.ExperimentConfig(
trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
task=progressive_masked_lm.ProgMaskedLMConfig(
train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path='dummy'),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
is_training=False,
input_path='dummy')))
experiment_config = params_dict.override_params_dict(
experiment_config, self._test_config, is_strict=False)
with distribution_strategy.scope():
task = task_factory.get_task(experiment_config.task,
logging_dir=model_dir)
_, logs = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=experiment_config,
model_dir=model_dir,
run_post_eval=run_post_eval)
if run_post_eval:
self.assertNotEmpty(logs)
else:
self.assertEmpty(logs)
if flag_mode == 'eval':
return
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
# Tests continuous evaluation.
_, logs = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='continuous_eval',
params=experiment_config,
model_dir=model_dir,
run_post_eval=run_post_eval)
print(logs)
if __name__ == '__main__':
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Progressive Trainer implementation.
The trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangeable and independent of model architectures and tasks.
"""
import os
from typing import Any, Optional
# Import libraries
from absl import logging
import dataclasses
import gin
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import base_trainer as trainer_lib
from official.core import config_definitions
from official.modeling.progressive import policies
ExperimentConfig = config_definitions.ExperimentConfig
@dataclasses.dataclass
class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
"""Configuration for progressive trainer.
Attributes:
progressive: A task-specific config. Users can subclass ProgressiveConfig
and define any task-specific settings in their subclass.
export_checkpoint: A bool. Whether to export checkpoints in non-progressive
manner (without the volatiles wrapper) such that your down-stream tasks
can load checkpoints from a progressive trainer as if it is a regular
checkpoint.
    export_checkpoint_interval: An int. The number of steps between exporting
      checkpoints. If None (by default), use the same value as
      TrainerConfig.checkpoint_interval.
    export_only_final_stage_ckpt: A bool. Whether to export checkpoints only
      during the final progressive training stage, i.e., whether to skip
      exporting small, partial models. In many cases, it is not meaningful to
      finetune a small, partial model in downstream tasks.
"""
progressive: Optional[policies.ProgressiveConfig] = None
export_checkpoint: bool = True
export_checkpoint_interval: Optional[int] = None
export_only_final_stage_ckpt: bool = True
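# Example (illustrative): a trainer config that exports plain checkpoints at
# every stage, at the same cadence as regular checkpoints:
#
#   config = ProgressiveTrainerConfig(
#       export_checkpoint=True,
#       export_checkpoint_interval=None,  # Falls back to checkpoint_interval.
#       export_only_final_stage_ckpt=False)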
class CheckpointWithHooks(tf.train.Checkpoint):
"""Same as tf.train.Checkpoint but supports hooks.
  In continuous_eval jobs, when a new checkpoint arrives, we have to update
  the model, optimizer, etc. to match the stage_id of the checkpoint. However,
  Orbit does not inform us when it loads a checkpoint, so we use this class to
  move the model to the correct stage before the checkpoint is restored.
"""
def __init__(self, before_load_hook, **kwargs):
self._before_load_hook = before_load_hook
super(CheckpointWithHooks, self).__init__(**kwargs)
# override
def read(self, save_path, options=None):
self._before_load_hook(save_path)
logging.info('Ran before_load_hook.')
super(CheckpointWithHooks, self).read(save_path=save_path, options=options)
@gin.configurable
class ProgressiveTrainer(trainer_lib.Trainer):
"""Implements the progressive trainer shared for TensorFlow models."""
def __init__(
self,
config: ExperimentConfig,
      prog_task: base_task.Task,  # Also implements ProgressivePolicy.
ckpt_dir: str = '',
train: bool = True,
evaluate: bool = True,
checkpoint_exporter: Any = None):
"""Initialize common trainer for TensorFlow models.
Args:
config: An `ExperimentConfig` instance specifying experiment config.
      prog_task: An instance that implements both policies.ProgressivePolicy
        and base_task.Task.
ckpt_dir: Checkpoint directory.
      train: bool, whether or not this trainer will be used for training.
        Defaults to True.
      evaluate: bool, whether or not this trainer will be used for evaluation.
        Defaults to True.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._config = config
self._task = prog_task
# Directory for non-progressive checkpoint
self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
tf.io.gfile.makedirs(self._export_ckpt_dir)
    # Receives other checkpoint exporters, e.g., the best checkpoint exporter.
# TODO(lehou): unify the checkpoint exporting logic, although the default
# setting does not use checkpoint_exporter.
self._checkpoint_exporter = checkpoint_exporter
self._global_step = orbit.utils.create_global_step()
self._checkpoint = CheckpointWithHooks(
before_load_hook=self._update_pt_stage_from_ckpt,
global_step=self.global_step,
**self._task.cur_checkpoint_items)
self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
self._validation_loss = tf.keras.metrics.Mean(
'validation_loss', dtype=tf.float32)
self._train_metrics = self.task.build_metrics(
training=True) + self.model.metrics
self._validation_metrics = self.task.build_metrics(
training=False) + self.model.metrics
if train:
orbit.StandardTrainer.__init__(
self,
None, # Manage train_dataset by ourselves, not by StandardTrainer.
options=orbit.StandardTrainerOptions(
use_tf_while_loop=config.trainer.train_tf_while_loop,
use_tf_function=config.trainer.train_tf_function))
if evaluate:
orbit.StandardEvaluator.__init__(
self,
          None,  # Manage eval_dataset by ourselves, not by StandardEvaluator.
options=orbit.StandardEvaluatorOptions(
use_tf_function=config.trainer.eval_tf_function))
@property
def model(self):
return self._task.cur_model
@property
def optimizer(self):
return self._task.cur_optimizer
# override
@property
def train_dataset(self):
"""Overriding StandardTrainer.train_dataset."""
return self._task.cur_train_dataset
# override
@train_dataset.setter
def train_dataset(self, _):
    raise SyntaxError('Please do not set train_dataset. Progressive training '
                      'relies on the progressive policy to manage the train '
                      'dataset.')
# override
@property
def eval_dataset(self):
"""Overriding StandardEvaluator.eval_dataset."""
return self._task.cur_eval_dataset
# override
@eval_dataset.setter
def eval_dataset(self, _):
    raise SyntaxError('Please do not set eval_dataset. Progressive training '
                      'relies on the progressive policy to manage the eval '
                      'dataset.')
def train_loop_end(self):
"""See base class."""
logs = {}
for metric in self.train_metrics + [self.train_loss]:
logs[metric.name] = metric.result()
metric.reset_states()
if callable(self.optimizer.learning_rate):
logs['learning_rate'] = self.optimizer.learning_rate(
self.optimizer.iterations)
else:
logs['learning_rate'] = self.optimizer.learning_rate
self._maybe_export_non_progressive_checkpoint(self._export_ckpt_dir)
if self._task.is_stage_advancing(self.global_step.numpy()):
old_train_dataset = self.train_dataset
# Update progressive properties
self._task.update_pt_stage(self.global_step.numpy())
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated model.
self._train_loop_fn = None
self._eval_loop_fn = None
if self.train_dataset != old_train_dataset:
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self._train_iter = None
return logs
def _update_pt_stage_from_ckpt(self, ckpt_file):
"""Update stage properties based on the global_step variable in a ckpt file.
    Before loading variables from a checkpoint file, we need to go to the
    correct stage and build the corresponding model and optimizer, to make sure
    that we restore variables of the right model and optimizer.
Args:
ckpt_file: Checkpoint file that will be restored/read from.
"""
if not ckpt_file:
return
ckpt = tf.train.Checkpoint(global_step=self.global_step)
ckpt.read(ckpt_file).expect_partial().assert_existing_objects_matched()
if self._task.is_stage_advancing(self.global_step.numpy()):
old_train_dataset = self.train_dataset
# Update progressive properties
self._task.update_pt_stage(self.global_step.numpy(), pass_old_model=False)
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated model.
self._train_loop_fn = None
self._eval_loop_fn = None
if self.train_dataset != old_train_dataset:
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self._train_iter = None
def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
"""Export checkpoints in non-progressive format.
This basically removes the wrapping of self._task.cur_checkpoint_items
-- just save the model, optimizer, etc., directly.
The purpose is to let your down-stream tasks to use these checkpoints.
Args:
export_ckpt_dir: A str. folder of exported checkpoints.
"""
if not self.config.trainer.export_checkpoint:
logging.info('Not exporting checkpoints.')
return
if not self._task.is_last_stage and (
self.config.trainer.export_only_final_stage_ckpt):
logging.info('Not exporting checkpoints until the last stage.')
return
global_step_np = self.global_step.numpy()
if self.config.trainer.export_checkpoint_interval is None:
step_interval = self.config.trainer.checkpoint_interval
else:
step_interval = self.config.trainer.export_checkpoint_interval
if global_step_np % step_interval != 0:
logging.info('Not exporting checkpoints in global step: %d.',
global_step_np)
return
# Create a checkpoint object just now, to make sure we use
# progressive_policy.cur_model and progressive_policy.cur_optimizer of the
# current stage.
if hasattr(self.model, 'checkpoint_items'):
checkpoint_items = self.model.checkpoint_items
else:
checkpoint_items = {}
checkpoint = tf.train.Checkpoint(
global_step=self.global_step,
model=self.model,
optimizer=self.optimizer,
**checkpoint_items)
file_prefix = os.path.join(export_ckpt_dir,
'ckpt-{}'.format(global_step_np))
checkpoint.save(file_prefix=file_prefix)
logging.info('Checkpoints exported: %s.', file_prefix)
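# Example (illustrative): a downstream task can restore an exported,
# non-progressive checkpoint directly, without the volatiles wrapper. The path
# below is hypothetical; tf.train.Checkpoint.save appends a save counter, so
# the prefix 'ckpt-450000' yields files like 'ckpt-450000-1':
#
#   ckpt = tf.train.Checkpoint(model=my_model)
#   ckpt.restore(
#       '/tmp/model_dir/exported_ckpts/ckpt-450000-1').expect_partial()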
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the progressive trainer."""
# pylint: disable=g-direct-tensorflow-import
import os
from absl.testing import parameterized
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling.progressive import policies
from official.modeling.progressive import trainer as trainer_lib
from official.nlp.configs import bert
from official.utils.testing import mock_task
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode='eager',
)
def get_exp_config():
return cfg.ExperimentConfig(
task=cfg.TaskConfig(
model=bert.PretrainerConfig()),
trainer=trainer_lib.ProgressiveTrainerConfig(
export_checkpoint=True,
export_checkpoint_interval=1,
export_only_final_stage_ckpt=False))
class TestPolicy(policies.ProgressivePolicy, mock_task.MockTask):
"""Just for testing purposes."""
def __init__(self, strategy, task_config, change_train_dataset=True):
self._strategy = strategy
self._change_train_dataset = change_train_dataset
self._my_train_dataset = None
mock_task.MockTask.__init__(self, params=task_config, logging_dir=None)
policies.ProgressivePolicy.__init__(self)
def num_stages(self) -> int:
return 2
def num_steps(self, stage_id: int) -> int:
return 2 if stage_id == 0 else 4
def get_model(self,
stage_id: int,
old_model: tf.keras.Model) -> tf.keras.Model:
del stage_id, old_model
return self.build_model()
def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
optimizer_type = 'sgd' if stage_id == 0 else 'adamw'
optimizer_config = cfg.OptimizationConfig({
'optimizer': {'type': optimizer_type},
'learning_rate': {'type': 'constant'}})
opt_factory = optimization.OptimizerFactory(optimizer_config)
return opt_factory.build_optimizer(opt_factory.build_learning_rate())
def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
if not self._change_train_dataset and self._my_train_dataset:
return self._my_train_dataset
if self._strategy:
self._my_train_dataset = orbit.utils.make_distributed_dataset(
self._strategy,
self._build_inputs,
stage_id)
else:
self._my_train_dataset = self._build_inputs(stage_id)
return self._my_train_dataset
def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
if self._strategy:
return orbit.utils.make_distributed_dataset(
self._strategy,
self._build_inputs,
stage_id)
return self._build_inputs(stage_id)
def _build_inputs(self, stage_id):
def dummy_data(_):
batch_size = 2 if stage_id == 0 else 1
x = tf.zeros(shape=(batch_size, 2), dtype=tf.float32)
label = tf.zeros(shape=(batch_size, 1), dtype=tf.float32)
return x, label
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
return dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainerTest, self).setUp()
self._config = get_exp_config()
def create_test_trainer(self, distribution, model_dir, change_train_dataset):
trainer = trainer_lib.ProgressiveTrainer(
self._config,
prog_task=TestPolicy(
distribution, self._config.task, change_train_dataset),
ckpt_dir=model_dir)
return trainer
@combinations.generate(all_strategy_combinations())
def test_checkpointing(self, distribution):
model_dir = self.get_temp_dir()
ckpt_file = os.path.join(model_dir, 'ckpt')
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
self.assertFalse(trainer._task.is_last_stage)
trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
self.assertTrue(trainer._task.is_last_stage)
trainer.checkpoint.save(ckpt_file)
trainer = self.create_test_trainer(distribution, model_dir, True)
self.assertFalse(trainer._task.is_last_stage)
trainer.checkpoint.restore(ckpt_file + '-1')
self.assertTrue(trainer._task.is_last_stage)
@combinations.generate(all_strategy_combinations())
def test_train_dataset(self, distribution):
model_dir = self.get_temp_dir()
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
# Using dataset of stage == 0
train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
train_data = train_iter.next()[0]
if distribution.num_replicas_in_sync > 1:
train_data = train_data.values[0]
self.assertEqual(train_data.shape[0], 2)
trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
# Using dataset of stage == 1
train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
train_data = train_iter.next()[0]
if distribution.num_replicas_in_sync > 1:
train_data = train_data.values[0]
self.assertEqual(train_data.shape[0], 1)
with self.assertRaises(SyntaxError):
trainer.train_dataset = None
@combinations.generate(all_strategy_combinations())
def test_train_dataset_no_switch(self, distribution):
model_dir = self.get_temp_dir()
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, False)
trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
# _train_iter is not reset since the dataset is not changed.
self.assertIsNotNone(trainer._train_iter)
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
# _train_iter is reset since the dataset changed.
self.assertIsNone(trainer._train_iter)
class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainerWithMaskedLMTaskTest, self).setUp()
self._config = get_exp_config()
def create_test_trainer(self, distribution):
trainer = trainer_lib.ProgressiveTrainer(
self._config,
prog_task=TestPolicy(distribution, self._config.task),
ckpt_dir=self.get_temp_dir())
return trainer
@combinations.generate(all_strategy_combinations())
def test_trainer_train(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer(distribution)
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
@combinations.generate(all_strategy_combinations())
def test_trainer_validate(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer(distribution)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('validation_loss', logs)
self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
@combinations.generate(
combinations.combine(
mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
loss_scale=[None, 'dynamic', 128, 256],
))
def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
config = cfg.ExperimentConfig(
task=cfg.TaskConfig(
model=bert.PretrainerConfig()),
runtime=cfg.RuntimeConfig(
mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
trainer=trainer_lib.ProgressiveTrainerConfig(
export_checkpoint=True,
export_checkpoint_interval=1,
export_only_final_stage_ckpt=False))
task = TestPolicy(None, config.task)
trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
if mixed_precision_dtype != 'float16':
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
elif mixed_precision_dtype == 'float16' and loss_scale is None:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
if __name__ == '__main__':
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util classes and functions."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.training.tracking import tracking
class VolatileTrackable(tracking.AutoTrackable):
"""A util class to keep Trackables that might change instances."""
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def reassign_trackable(self, **kwargs):
for k, v in kwargs.items():
delattr(self, k) # untrack this object
setattr(self, k, v) # track the new object
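# Example (illustrative): swap the optimizer tracked inside a checkpointed
# object graph without rebuilding the surrounding tf.train.Checkpoint:
#
#   volatiles = VolatileTrackable(optimizer=tf.keras.optimizers.SGD())
#   ckpt = tf.train.Checkpoint(volatiles=volatiles)
#   volatiles.reassign_trackable(optimizer=tf.keras.optimizers.Adam())
#   # `ckpt` now saves/restores the Adam optimizer under the same name.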
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Masked language task with progressive training."""
from typing import List
# Import libraries
from absl import logging
import dataclasses
import orbit
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.modeling.progressive import policies
from official.nlp.tasks import masked_lm
@dataclasses.dataclass
class StackingStageConfig(base_config.Config):
num_layers: int = 0
num_steps: int = 0
warmup_steps: int = 10000
initial_learning_rate: float = 1e-4
end_learning_rate: float = 0.0
decay_steps: int = 1000000
@dataclasses.dataclass
class ProgMaskedLMConfig(masked_lm.MaskedLMConfig):
"""The progressive model config."""
optimizer_config: optimization.OptimizationConfig = (
optimization.OptimizationConfig(
optimizer=optimization.OptimizerConfig(type='adamw'),
learning_rate=optimization.LrConfig(type='polynomial'),
warmup=optimization.WarmupConfig(type='polynomial'),
)
)
stage_list: List[StackingStageConfig] = dataclasses.field(
default_factory=lambda: [ # pylint: disable=g-long-lambda
StackingStageConfig(num_layers=3,
num_steps=112500,
warmup_steps=10000,
initial_learning_rate=1e-4,
end_learning_rate=1e-4,
decay_steps=112500),
StackingStageConfig(num_layers=6,
num_steps=112500,
warmup_steps=10000,
initial_learning_rate=1e-4,
end_learning_rate=1e-4,
decay_steps=112500),
StackingStageConfig(num_layers=12,
num_steps=450000,
warmup_steps=10000,
initial_learning_rate=1e-4,
end_learning_rate=0.0,
decay_steps=450000)])
@task_factory.register_task_cls(ProgMaskedLMConfig)
class ProgressiveMaskedLM(policies.ProgressivePolicy, masked_lm.MaskedLMTask):
"""Masked Language Model that supports progressive training.
Inherate from the MaskedLmTask class to build model datasets etc.
"""
def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
masked_lm.MaskedLMTask.__init__(
self, params=params, logging_dir=logging_dir)
self._model_config = params.model
self._optimizer_config = params.optimizer_config
self._the_only_train_dataset = None
self._the_only_eval_dataset = None
policies.ProgressivePolicy.__init__(self)
# Override
def num_stages(self):
return len(self.task_config.stage_list)
# Override
def num_steps(self, stage_id):
return self.task_config.stage_list[stage_id].num_steps
# Override
def get_model(self, stage_id, old_model=None):
"""Build model for each stage."""
num_layers = self.task_config.stage_list[stage_id].num_layers
encoder_type = self._model_config.encoder.type
params = self._model_config.replace(
encoder={encoder_type: {
'num_layers': num_layers
}})
model = self.build_model(params)
# Run the model once, to make sure that all layers are built.
# Otherwise, not all weights will be copied.
_ = model(model.inputs)
if stage_id > 0 and old_model is not None:
logging.info('Stage %d copying weights.', stage_id)
self._copy_weights_to_new_model(old_model=old_model,
new_model=model)
return model
# Override
def get_optimizer(self, stage_id):
"""Build optimizer for each stage."""
    stage_config = self.task_config.stage_list[stage_id]
    params = self._optimizer_config.replace(
        learning_rate={
            'polynomial': {
                'decay_steps': stage_config.decay_steps,
                'initial_learning_rate': stage_config.initial_learning_rate,
                'end_learning_rate': stage_config.end_learning_rate,
                'power': 1,
                'cycle': False,
            }
        },
        warmup={
            'polynomial': {
                'warmup_steps': stage_config.warmup_steps,
                'power': 1,
            }
        })
opt_factory = optimization.OptimizerFactory(params)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
return optimizer
# overrides policies.ProgressivePolicy
def get_train_dataset(self, stage_id):
del stage_id
if self._the_only_train_dataset is None:
strategy = tf.distribute.get_strategy()
self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
strategy,
self.build_inputs,
self.task_config.train_data)
return self._the_only_train_dataset
# overrides policies.ProgressivePolicy
def get_eval_dataset(self, stage_id):
del stage_id
if self._the_only_eval_dataset is None:
strategy = tf.distribute.get_strategy()
self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
strategy,
self.build_inputs,
self.task_config.validation_data)
return self._the_only_eval_dataset
def _copy_weights_to_new_model(self, old_model, new_model):
"""Copy model weights from the previous stage to the next.
Args:
old_model: nlp.modeling.models.bert_pretrainer.BertPretrainerV2. Model of
the previous stage.
new_model: nlp.modeling.models.bert_pretrainer.BertPretrainerV2. Model of
the next stage.
"""
# Copy weights of the embedding layers.
# pylint: disable=protected-access
# When using `encoder_scaffold`, there may be `_embedding_network`.
if hasattr(new_model.encoder_network, '_embedding_network') and hasattr(
old_model.encoder_network, '_embedding_network') and (
new_model.encoder_network._embedding_network is not None):
new_model.encoder_network._embedding_network.set_weights(
old_model.encoder_network._embedding_network.get_weights())
else:
new_model.encoder_network._embedding_layer.set_weights(
old_model.encoder_network._embedding_layer.get_weights())
new_model.encoder_network._position_embedding_layer.set_weights(
old_model.encoder_network._position_embedding_layer.get_weights())
new_model.encoder_network._type_embedding_layer.set_weights(
old_model.encoder_network._type_embedding_layer.get_weights())
new_model.encoder_network._embedding_norm_layer.set_weights(
old_model.encoder_network._embedding_norm_layer.get_weights())
if hasattr(new_model.encoder_network, '_embedding_projection') and hasattr(
old_model.encoder_network, '_embedding_projection'):
if old_model.encoder_network._embedding_projection is not None:
new_model.encoder_network._embedding_projection.set_weights(
old_model.encoder_network._embedding_projection.get_weights())
# pylint: enable=protected-access
# Copy weights of the transformer layers.
# The model can be EncoderScaffold or TransformerEncoder.
if hasattr(old_model.encoder_network, 'hidden_layers'):
old_layer_group = old_model.encoder_network.hidden_layers
elif hasattr(old_model.encoder_network, 'transformer_layers'):
old_layer_group = old_model.encoder_network.transformer_layers
else:
raise ValueError('Unrecognized encoder network: {}'.format(
old_model.encoder_network))
if hasattr(new_model.encoder_network, 'hidden_layers'):
new_layer_group = new_model.encoder_network.hidden_layers
elif hasattr(new_model.encoder_network, 'transformer_layers'):
new_layer_group = new_model.encoder_network.transformer_layers
else:
raise ValueError('Unrecognized encoder network: {}'.format(
new_model.encoder_network))
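    # Stacking copy: new layer i receives the weights of old layer
    # (i % num_old_layers). For example, growing from 3 to 6 layers copies
    # old layers [0, 1, 2, 0, 1, 2] into the new stack.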
for new_layer_idx in range(len(new_layer_group)):
old_layer_idx = new_layer_idx % len(old_layer_group)
new_layer_group[new_layer_idx].set_weights(
old_layer_group[old_layer_idx].get_weights())
if old_layer_idx != new_layer_idx:
if hasattr(new_layer_group[new_layer_idx], 'reset_rezero'):
# Reset ReZero's alpha to 0.
new_layer_group[new_layer_idx].reset_rezero()
# Copy weights of the final layer norm (if needed).
# pylint: disable=protected-access
if hasattr(new_model.encoder_network, '_output_layer_norm') and hasattr(
old_model.encoder_network, '_output_layer_norm'):
new_model.encoder_network._output_layer_norm.set_weights(
old_model.encoder_network._output_layer_norm.get_weights())
# pylint: enable=protected-access
# Copy weights of the pooler layer.
new_model.encoder_network.pooler_layer.set_weights(
old_model.encoder_network.pooler_layer.get_weights())
# Copy weights of the classification head.
for idx in range(len(new_model.classification_heads)):
new_model.classification_heads[idx].set_weights(
old_model.classification_heads[idx].get_weights())
# Copy weights of the masked_lm layer.
new_model.masked_lm.set_weights(old_model.masked_lm.get_weights())
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for google.nlp.progressive_masked_lm."""
# Import libraries
from absl.testing import parameterized
import gin
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import progressive_masked_lm
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode="eager",
)
class ProgressiveMaskedLMTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ProgressiveMaskedLMTest, self).setUp()
self.task_config = progressive_masked_lm.ProgMaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=2)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
]),
train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path="dummy",
max_predictions_per_seq=20,
seq_length=128,
global_batch_size=1),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
input_path="dummy",
max_predictions_per_seq=20,
seq_length=128,
global_batch_size=1),
stage_list=[
progressive_masked_lm.StackingStageConfig(
num_layers=1, num_steps=4),
progressive_masked_lm.StackingStageConfig(
num_layers=2, num_steps=8),
],
)
self.exp_config = cfg.ExperimentConfig(
task=self.task_config,
trainer=prog_trainer_lib.ProgressiveTrainerConfig())
@combinations.generate(all_strategy_combinations())
def test_num_stages(self, distribution):
with distribution.scope():
prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
self.task_config)
self.assertEqual(prog_masked_lm.num_stages(), 2)
self.assertEqual(prog_masked_lm.num_steps(0), 4)
self.assertEqual(prog_masked_lm.num_steps(1), 8)
@combinations.generate(all_strategy_combinations())
def test_weight_copying(self, distribution):
with distribution.scope():
prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
self.task_config)
old_model = prog_masked_lm.get_model(stage_id=0)
for w in old_model.trainable_weights:
w.assign(tf.zeros_like(w) + 0.12345)
new_model = prog_masked_lm.get_model(stage_id=1, old_model=old_model)
for w in new_model.trainable_weights:
self.assertAllClose(w, tf.zeros_like(w) + 0.12345)
gin.parse_config_files_and_bindings(
None, "encoders.build_encoder.encoder_cls = @EncoderScaffold")
with distribution.scope():
prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
self.task_config)
old_model = prog_masked_lm.get_model(stage_id=0)
for w in old_model.trainable_weights:
w.assign(tf.zeros_like(w) + 0.12345)
new_model = prog_masked_lm.get_model(stage_id=1, old_model=old_model)
for w in new_model.trainable_weights:
self.assertAllClose(w, tf.zeros_like(w) + 0.12345)
if __name__ == "__main__":
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Translation task with progressive training."""
from typing import List
# Import libraries
from absl import logging
import dataclasses
import orbit
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.modeling.progressive import policies
from official.nlp.modeling import models
from official.nlp.tasks import translation
@dataclasses.dataclass
class StackingStageConfig(base_config.Config):
num_encoder_layers: int = 0
num_decoder_layers: int = 0
num_steps: int = 0
warmup_steps: int = 10000
initial_learning_rate: float = 0.0625
power: float = -0.5
@dataclasses.dataclass
class ProgTranslationConfig(translation.TranslationConfig):
"""The progressive model config."""
model: translation.ModelConfig = translation.ModelConfig(
encoder=translation.EncDecoder(
num_attention_heads=16, intermediate_size=4096),
decoder=translation.EncDecoder(
num_attention_heads=16, intermediate_size=4096),
embedding_width=1024,
padded_decode=True,
decode_max_length=100)
optimizer_config: optimization.OptimizationConfig = (
optimization.OptimizationConfig({
'optimizer': {
'type': 'adam',
'adam': {
'beta_2': 0.997,
'epsilon': 1e-9,
},
},
'learning_rate': {
'type': 'power',
'power': {
'initial_learning_rate': 0.0625,
'power': -0.5,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 16000,
'warmup_learning_rate': 0.0
}
}
}))
stage_list: List[StackingStageConfig] = dataclasses.field(
default_factory=lambda: [ # pylint: disable=g-long-lambda
StackingStageConfig(num_encoder_layers=3,
num_decoder_layers=3,
num_steps=20000,
warmup_steps=5000,
initial_learning_rate=0.0625),
StackingStageConfig(num_encoder_layers=6,
num_decoder_layers=6,
num_steps=20000,
warmup_steps=5000,
initial_learning_rate=0.0625),
StackingStageConfig(num_encoder_layers=12,
num_decoder_layers=12,
num_steps=100000,
warmup_steps=5000,
initial_learning_rate=0.0625)])
@task_factory.register_task_cls(ProgTranslationConfig)
class ProgressiveTranslationTask(policies.ProgressivePolicy,
translation.TranslationTask):
"""Masked Language Model that supports progressive training.
Inherate from the TranslationTask class to build model datasets etc.
"""
def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
translation.TranslationTask.__init__(
self, params=params, logging_dir=logging_dir)
self._model_config = params.model
self._optimizer_config = params.optimizer_config
self._the_only_train_dataset = None
self._the_only_eval_dataset = None
policies.ProgressivePolicy.__init__(self)
# Override
def num_stages(self):
return len(self.task_config.stage_list)
# Override
def num_steps(self, stage_id):
return self.task_config.stage_list[stage_id].num_steps
# Override
def get_model(self, stage_id, old_model=None):
"""Build model for each stage."""
num_encoder_layers = (
self.task_config.stage_list[stage_id].num_encoder_layers)
num_decoder_layers = (
self.task_config.stage_list[stage_id].num_decoder_layers)
params = self._model_config.replace(
encoder={'num_layers': num_encoder_layers},
decoder={'num_layers': num_decoder_layers})
model = self.build_model(params)
# Run the model once, to make sure that all layers are built.
# Otherwise, not all weights will be copied.
inputs = next(tf.nest.map_structure(
iter, self.build_inputs(self.task_config.train_data)))
model(inputs, training=True)
if stage_id > 0 and old_model is not None:
logging.info('Stage %d copying weights.', stage_id)
self._copy_weights_to_new_model(old_model=old_model,
new_model=model)
return model
# Override
def build_model(self, params) -> tf.keras.Model:
"""Creates model architecture."""
model_cfg = params or self.task_config.model
encoder_kwargs = model_cfg.encoder.as_dict()
encoder_layer = models.TransformerEncoder(**encoder_kwargs)
decoder_kwargs = model_cfg.decoder.as_dict()
decoder_layer = models.TransformerDecoder(**decoder_kwargs)
return models.Seq2SeqTransformer(
vocab_size=self._vocab_size,
embedding_width=model_cfg.embedding_width,
dropout_rate=model_cfg.dropout_rate,
padded_decode=model_cfg.padded_decode,
decode_max_length=model_cfg.decode_max_length,
beam_size=model_cfg.beam_size,
alpha=model_cfg.alpha,
encoder_layer=encoder_layer,
decoder_layer=decoder_layer)
# Override
def get_optimizer(self, stage_id):
"""Build optimizer for each stage."""
    stage_config = self.task_config.stage_list[stage_id]
    params = self._optimizer_config.replace(
        warmup={
            'linear': {
                'warmup_steps': stage_config.warmup_steps,
            },
        },
        learning_rate={
            'power': {
                'initial_learning_rate': stage_config.initial_learning_rate,
            },
        },
    )
opt_factory = optimization.OptimizerFactory(params)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
return optimizer
# overrides policies.ProgressivePolicy
def get_train_dataset(self, stage_id):
del stage_id
if self._the_only_train_dataset is None:
strategy = tf.distribute.get_strategy()
self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
strategy,
self.build_inputs,
self.task_config.train_data)
return self._the_only_train_dataset
# overrides policies.ProgressivePolicy
def get_eval_dataset(self, stage_id):
del stage_id
if self._the_only_eval_dataset is None:
strategy = tf.distribute.get_strategy()
self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
strategy,
self.build_inputs,
self.task_config.validation_data)
return self._the_only_eval_dataset
def _copy_weights_to_new_model(self, old_model, new_model):
"""Copy model weights from the previous stage to the next.
Args:
      old_model: nlp.modeling.models.Seq2SeqTransformer. Model of the previous
        stage.
      new_model: nlp.modeling.models.Seq2SeqTransformer. Model of the next
        stage.
"""
new_model.embedding_lookup.set_weights(
old_model.embedding_lookup.get_weights())
new_model.position_embedding.set_weights(
old_model.position_embedding.get_weights())
new_model.encoder_layer.output_normalization.set_weights(
old_model.encoder_layer.output_normalization.get_weights())
new_model.decoder_layer.output_normalization.set_weights(
old_model.decoder_layer.output_normalization.get_weights())
old_layer_group = old_model.encoder_layer.encoder_layers
new_layer_group = new_model.encoder_layer.encoder_layers
for new_layer_idx in range(len(new_layer_group)):
old_layer_idx = new_layer_idx % len(old_layer_group)
new_layer_group[new_layer_idx].set_weights(
old_layer_group[old_layer_idx].get_weights())
old_layer_group = old_model.decoder_layer.decoder_layers
new_layer_group = new_model.decoder_layer.decoder_layers
for new_layer_idx in range(len(new_layer_group)):
old_layer_idx = new_layer_idx % len(old_layer_group)
new_layer_group[new_layer_idx].set_weights(
old_layer_group[old_layer_idx].get_weights())
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for google.nlp.progressive_translation."""
import os
from absl.testing import parameterized
import tensorflow as tf
from sentencepiece import SentencePieceTrainer
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import wmt_dataloader
from official.nlp.tasks import progressive_translation
from official.nlp.tasks import translation
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode="eager",
)
def _generate_line_file(filepath, lines):
with tf.io.gfile.GFile(filepath, "w") as f:
for l in lines:
f.write("{}\n".format(l))
def _generate_record_file(filepath, src_lines, tgt_lines):
writer = tf.io.TFRecordWriter(filepath)
for src, tgt in zip(src_lines, tgt_lines):
example = tf.train.Example(
features=tf.train.Features(
feature={
"en": tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[src.encode()])),
"reverse_en": tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[tgt.encode()])),
}))
writer.write(example.SerializeToString())
writer.close()
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
argstr = " ".join([
f"--input={input_path}", f"--vocab_size={vocab_size}",
"--character_coverage=0.995",
f"--model_prefix={model_path}", "--model_type=bpe",
"--bos_id=-1", "--pad_id=0", f"--eos_id={eos_id}", "--unk_id=2"
])
SentencePieceTrainer.Train(argstr)
class ProgressiveTranslationTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ProgressiveTranslationTest, self).setUp()
self._temp_dir = self.get_temp_dir()
src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
tgt_lines = ["dd cc a ef g", "bcd ef a g", "gef cd ba"]
self._record_input_path = os.path.join(self._temp_dir, "train.record")
_generate_record_file(self._record_input_path, src_lines, tgt_lines)
    self._sentencepiece_input_path = os.path.join(self._temp_dir, "inputs.txt")
    _generate_line_file(self._sentencepiece_input_path, src_lines + tgt_lines)
    sentencepiece_model_prefix = os.path.join(self._temp_dir, "sp")
    _train_sentencepiece(self._sentencepiece_input_path, 11,
                         sentencepiece_model_prefix)
    self._sentencepiece_model_path = "{}.model".format(
        sentencepiece_model_prefix)
encdecoder = translation.EncDecoder(
num_attention_heads=2, intermediate_size=8)
self.task_config = progressive_translation.ProgTranslationConfig(
model=translation.ModelConfig(
encoder=encdecoder,
decoder=encdecoder,
embedding_width=8,
padded_decode=True,
decode_max_length=100),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
is_training=True,
global_batch_size=24,
static_batch=True,
src_lang="en",
tgt_lang="reverse_en",
max_seq_length=12),
validation_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
is_training=False,
global_batch_size=2,
static_batch=True,
src_lang="en",
tgt_lang="reverse_en",
max_seq_length=12),
        sentencepiece_model_path=self._sentencepiece_model_path,
stage_list=[
progressive_translation.StackingStageConfig(
num_encoder_layers=1, num_decoder_layers=1, num_steps=4),
progressive_translation.StackingStageConfig(
num_encoder_layers=2, num_decoder_layers=1, num_steps=8),
],
)
self.exp_config = cfg.ExperimentConfig(
task=self.task_config,
trainer=prog_trainer_lib.ProgressiveTrainerConfig())
@combinations.generate(all_strategy_combinations())
def test_num_stages(self, distribution):
with distribution.scope():
prog_translation = progressive_translation.ProgressiveTranslationTask(
self.task_config)
self.assertEqual(prog_translation.num_stages(), 2)
self.assertEqual(prog_translation.num_steps(0), 4)
self.assertEqual(prog_translation.num_steps(1), 8)
@combinations.generate(all_strategy_combinations())
def test_weight_copying(self, distribution):
with distribution.scope():
prog_translation = progressive_translation.ProgressiveTranslationTask(
self.task_config)
old_model = prog_translation.get_model(stage_id=0)
for w in old_model.trainable_weights:
w.assign(tf.zeros_like(w) + 0.12345)
new_model = prog_translation.get_model(stage_id=1, old_model=old_model)
for w in new_model.trainable_weights:
self.assertAllClose(w, tf.zeros_like(w) + 0.12345)
if __name__ == "__main__":
tf.test.main()