Commit e1cb663e authored by Hongkun Yu, committed by A. Unique TensorFlower

Initial release of NHNet: https://arxiv.org/abs/2001.09386

PiperOrigin-RevId: 306256383
parent 057895af
[
{
"urls": [
"http://url_000.html",
"http://url_001.html"
],
"label": "headline 0"
},
{
"urls": [
"http://url_000.html",
"http://url_001.html"
],
"label": "headline 1"
},
{
"urls": [
"http://url_002.html",
"http://url_001.html"
],
"label": "headline 2"
},
{
"urls": [
"http://url_003.html"
],
"label": "headline 3"
}
]
[UNK]
[CLS]
[SEP]
[MASK]
0
1
this
is
a
title
snippet
for
url
main
text
http
www
html
:
//
.
_
headline
official/nlp/nhnet/trainer.py
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run NHNet model training and eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from absl import logging
from six.moves import zip
import tensorflow as tf
from official.modeling.hyperparams import params_dict
from official.nlp.nhnet import evaluation
from official.nlp.nhnet import input_pipeline
from official.nlp.nhnet import models
from official.nlp.nhnet import optimizer
from official.nlp.transformer import metrics as transformer_metrics
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
def define_flags():
"""Defines command line flags used by NHNet trainer."""
## Required parameters
flags.DEFINE_enum("mode", "train", ["train", "eval", "train_and_eval"],
"Execution mode.")
flags.DEFINE_string("train_file_pattern", "", "Train file pattern.")
flags.DEFINE_string("eval_file_pattern", "", "Eval file pattern.")
flags.DEFINE_string(
"model_dir", None,
"The output directory where the model checkpoints will be written.")
# Model training specific flags.
flags.DEFINE_enum(
"distribution_strategy", "mirrored", ["tpu", "mirrored"],
"Distribution Strategy type to use for training. `tpu` uses TPUStrategy "
"for running on TPUs, `mirrored` uses GPUs with single host.")
flags.DEFINE_string("tpu", "", "TPU address to connect to.")
flags.DEFINE_string(
"init_checkpoint", None,
"Initial checkpoint (usually from a pre-trained BERT model).")
flags.DEFINE_integer("train_steps", 100000, "Max train steps")
flags.DEFINE_integer("eval_steps", 32, "Number of eval steps per run.")
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
flags.DEFINE_integer("eval_batch_size", 4, "Total batch size for evaluation.")
flags.DEFINE_integer(
"steps_per_loop", 1000,
"Number of steps per graph-mode loop. Only training step "
"happens inside the loop.")
flags.DEFINE_integer("checkpoint_interval", 2000, "Checkpointing interval.")
flags.DEFINE_integer("len_title", 15, "Title length.")
flags.DEFINE_integer("len_passage", 200, "Passage length.")
flags.DEFINE_integer("num_encoder_layers", 12,
"Number of hidden layers of encoder.")
flags.DEFINE_integer("num_decoder_layers", 12,
"Number of hidden layers of decoder.")
flags.DEFINE_string("model_type", "nhnet",
"Model type to choose a model configuration.")
flags.DEFINE_integer(
"num_nhnet_articles", 5,
"Maximum number of articles in NHNet, only used when model_type=nhnet")
flags.DEFINE_string(
"params_override",
default=None,
help=("a YAML/JSON string or a YAML file which specifies additional "
"overrides over the default parameters"))
# pylint: disable=protected-access
class Trainer(tf.keras.Model):
"""A training only model."""
def __init__(self, model, params):
super(Trainer, self).__init__()
self.model = model
self.params = params
self._num_replicas_in_sync = tf.distribute.get_strategy(
).num_replicas_in_sync
def call(self, inputs, mode="train"):
return self.model(inputs, mode)
def train_step(self, inputs):
"""The logic for one training step."""
with tf.GradientTape() as tape:
logits, _, _ = self(inputs, mode="train", training=True)
targets = models.remove_sos_from_seq(inputs["target_ids"],
self.params.pad_token_id)
loss = transformer_metrics.transformer_loss(logits, targets,
self.params.label_smoothing,
self.params.vocab_size)
# Scales the loss, which results in using the average loss across all
# of the replicas for backprop.
scaled_loss = loss / self._num_replicas_in_sync
tvars = self.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
self.optimizer.apply_gradients(list(zip(grads, tvars)))
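    # Reports the unscaled loss together with the current learning rate;
    # `_decayed_lr` is a private Keras optimizer helper (hence the file-level
    # pylint disable above).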
return {
"training_loss": loss,
"learning_rate": self.optimizer._decayed_lr(var_dtype=tf.float32)
}
class SimpleCheckpoint(tf.keras.callbacks.Callback):
"""Keras callback to save tf.train.Checkpoints."""
def __init__(self, checkpoint_manager):
super(SimpleCheckpoint, self).__init__()
self.checkpoint_manager = checkpoint_manager
def on_epoch_end(self, epoch, logs=None):
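    # Uses the manager's internal step counter so that checkpoints are
    # numbered by global step rather than by Keras epoch.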
step_counter = self.checkpoint_manager._step_counter.numpy()
self.checkpoint_manager.save(checkpoint_number=step_counter)
def train(params, strategy, dataset=None):
"""Runs training."""
if not dataset:
dataset = input_pipeline.get_input_dataset(
FLAGS.train_file_pattern,
FLAGS.train_batch_size,
params,
is_training=True,
strategy=strategy)
with strategy.scope():
model = models.create_model(
FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
opt = optimizer.create_optimizer(params)
trainer = Trainer(model, params)
model.global_step = opt.iterations
trainer.compile(
optimizer=opt,
experimental_steps_per_execution=FLAGS.steps_per_loop)
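    # `experimental_steps_per_execution` runs `steps_per_loop` training steps
    # inside a single tf.function call, reducing host-device launch overhead.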
summary_dir = os.path.join(FLAGS.model_dir, "summaries")
summary_callback = tf.keras.callbacks.TensorBoard(
summary_dir, update_freq=max(100, FLAGS.steps_per_loop))
checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=FLAGS.model_dir,
max_to_keep=10,
step_counter=model.global_step,
checkpoint_interval=FLAGS.checkpoint_interval)
if checkpoint_manager.restore_or_initialize():
logging.info("Training restored from the checkpoints in: %s",
FLAGS.model_dir)
checkpoint_callback = SimpleCheckpoint(checkpoint_manager)
# Trains the model.
steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
epochs = FLAGS.train_steps // steps_per_epoch
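  # Keras triggers the checkpoint callback at epoch boundaries, so training is
  # split into pseudo-epochs of `checkpoint_interval` steps each.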
trainer.fit(
x=dataset,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
callbacks=[summary_callback, checkpoint_callback],
verbose=2)
def run():
"""Runs NHNet using Keras APIs."""
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu)
if strategy:
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
params = models.get_model_params(FLAGS.model_type)
params = params_dict.override_params_dict(
params, FLAGS.params_override, is_strict=True)
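  # `params_override` may be a YAML/JSON string or a YAML file path (see the
  # flag definition above); is_strict=True rejects keys that are not already
  # present in the model's default params.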
params.override(
{
"len_title":
FLAGS.len_title,
"len_passage":
FLAGS.len_passage,
"num_hidden_layers":
FLAGS.num_encoder_layers,
"num_decoder_layers":
FLAGS.num_decoder_layers,
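          # One single-letter key per article ("b", "c", ...); assumed to
          # match the passage naming used by the NHNet input pipeline.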
"passage_list":
[chr(ord("b") + i) for i in range(FLAGS.num_nhnet_articles)],
},
is_strict=False)
stats = {}
if "train" in FLAGS.mode:
train(params, strategy)
if "eval" in FLAGS.mode:
timeout = 0 if FLAGS.mode == "train_and_eval" else 3000
# Uses padded decoding for TPU. Always uses cache.
padded_decode = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
params.override({
"padded_decode": padded_decode,
}, is_strict=False)
stats = evaluation.continuous_eval(
strategy,
params,
model_type=FLAGS.model_type,
eval_file_pattern=FLAGS.eval_file_pattern,
batch_size=FLAGS.eval_batch_size,
eval_steps=FLAGS.eval_steps,
model_dir=FLAGS.model_dir,
timeout=timeout)
return stats
def main(_):
stats = run()
if stats:
logging.info("Stats:\n%s", stats)
if __name__ == "__main__":
define_flags()
app.run(main)
official/nlp/nhnet/trainer_test.py
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.nhnet.trainer."""
import os
from absl import flags
from absl.testing import parameterized
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
# pylint: enable=g-direct-tensorflow-import
from official.nlp.nhnet import trainer
from official.nlp.nhnet import utils
FLAGS = flags.FLAGS
trainer.define_flags()
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.tpu_strategy,
],
mode="eager",
)
def get_trivial_data(config) -> tf.data.Dataset:
"""Gets trivial data in the ImageNet size."""
batch_size, num_docs = 2, len(config.passage_list),
len_passage = config.len_passage
len_title = config.len_title
  def generate_data(_) -> dict:
fake_ids = tf.zeros((num_docs, len_passage), dtype=tf.int32)
    title = tf.zeros((len_title,), dtype=tf.int32)
return dict(
input_ids=fake_ids,
input_mask=fake_ids,
segment_ids=fake_ids,
target_ids=title)
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(generate_data,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=1).batch(batch_size)
return dataset
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainerTest, self).setUp()
self._config = utils.get_test_params()
self._config.override(
{
"vocab_size": 49911,
"max_position_embeddings": 200,
"len_title": 15,
"len_passage": 20,
"beam_size": 5,
"alpha": 0.6,
"learning_rate": 0.0,
"learning_rate_warmup_steps": 0,
"multi_channel_cross_attention": True,
"passage_list": ["a", "b"],
},
is_strict=False)
@combinations.generate(all_strategy_combinations())
def test_train(self, distribution):
FLAGS.train_steps = 10
FLAGS.checkpoint_interval = 5
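    # 10 train steps with a checkpoint every 5 steps should yield exactly two
    # checkpoints, as asserted below.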
FLAGS.model_dir = self.get_temp_dir()
FLAGS.model_type = "nhnet"
trainer.train(self._config, distribution, get_trivial_data(self._config))
self.assertLen(
tf.io.gfile.glob(os.path.join(FLAGS.model_dir, "ckpt*.index")), 2)
if __name__ == "__main__":
tf.test.main()
official/nlp/nhnet/utils.py
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility helpers for Bert2Bert."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from typing import Optional, Text
from absl import logging
import tensorflow as tf
from official.modeling.hyperparams import params_dict
from official.nlp.bert import configs
from official.nlp.nhnet import configs as nhnet_configs
def get_bert_config_from_params(
params: params_dict.ParamsDict) -> configs.BertConfig:
"""Converts a BertConfig to ParamsDict."""
return configs.BertConfig.from_dict(params.as_dict())
def get_test_params(cls=nhnet_configs.BERT2BERTConfig):
  """Returns a config object populated with the unit-test parameters."""
  return cls.from_args(**nhnet_configs.UNITTEST_CONFIG)
# pylint: disable=protected-access
def encoder_common_layers(transformer_block):
  """Returns the encoder block sub-layers that have decoder counterparts."""
  return [
      transformer_block._attention_layer,
      transformer_block._attention_output_dense,
      transformer_block._attention_layer_norm,
      transformer_block._intermediate_dense,
      transformer_block._output_dense,
      transformer_block._output_layer_norm,
  ]
# pylint: enable=protected-access
def initialize_bert2bert_from_pretrained_bert(
bert_encoder: tf.keras.layers.Layer,
bert_decoder: tf.keras.layers.Layer,
init_checkpoint: Optional[Text] = None) -> None:
"""Helper function to initialze Bert2Bert from Bert pretrained checkpoint."""
ckpt = tf.train.Checkpoint(model=bert_encoder)
  logging.info(
      "Restoring the core model from the initial checkpoint: %s",
      init_checkpoint)
status = ckpt.restore(init_checkpoint)
  # Expects the BERT model to be a strict subset of the checkpoint, since the
  # pooling layer is not used.
status.assert_existing_objects_matched()
logging.info("Loading from checkpoint file completed.")
  # Collects the encoder's transformer sub-layers.
encoder_layers = []
for transformer_block in bert_encoder.transformer_layers:
encoder_layers.extend(encoder_common_layers(transformer_block))
  # Collects the decoder sub-layers that mirror the encoder layers.
decoder_layers_to_initialize = []
for decoder_block in bert_decoder.decoder.layers:
decoder_layers_to_initialize.extend(
decoder_block.common_layers_with_encoder())
if len(decoder_layers_to_initialize) != len(encoder_layers):
    raise ValueError(
        "Destination decoder layers with %d objects do not match source "
        "encoder layers with %d objects." %
        (len(decoder_layers_to_initialize), len(encoder_layers)))
for dest_layer, source_layer in zip(decoder_layers_to_initialize,
encoder_layers):
try:
dest_layer.set_weights(source_layer.get_weights())
except ValueError as e:
logging.error(
"dest_layer: %s failed to set weights from "
"source_layer: %s as %s", dest_layer.name, source_layer.name, str(e))