Unverified Commit f16a7b5b authored by vedanshu, committed by GitHub

Merge pull request #1 from tensorflow/master

new pull
parents 8e9296ff 8f58f396
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive trainer."""
# pylint: disable=g-direct-tensorflow-import
import os
from absl.testing import parameterized
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling.progressive import policies
from official.modeling.progressive import trainer as trainer_lib
from official.nlp.configs import bert
from official.utils.testing import mock_task
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],)
def get_exp_config():
return cfg.ExperimentConfig(
task=cfg.TaskConfig(
model=bert.PretrainerConfig()),
trainer=trainer_lib.ProgressiveTrainerConfig(
export_checkpoint=True,
export_checkpoint_interval=1,
export_only_final_stage_ckpt=False))
class TestPolicy(policies.ProgressivePolicy, mock_task.MockTask):
"""Just for testing purposes."""
def __init__(self, strategy, task_config, change_train_dataset=True):
self._strategy = strategy
self._change_train_dataset = change_train_dataset
self._my_train_dataset = None
mock_task.MockTask.__init__(self, params=task_config, logging_dir=None)
policies.ProgressivePolicy.__init__(self)
def num_stages(self) -> int:
return 2
def num_steps(self, stage_id: int) -> int:
return 2 if stage_id == 0 else 4
def get_model(self,
stage_id: int,
old_model: tf.keras.Model) -> tf.keras.Model:
del stage_id, old_model
return self.build_model()
def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
optimizer_type = 'sgd' if stage_id == 0 else 'adamw'
optimizer_config = cfg.OptimizationConfig({
'optimizer': {'type': optimizer_type},
'learning_rate': {'type': 'constant'}})
opt_factory = optimization.OptimizerFactory(optimizer_config)
return opt_factory.build_optimizer(opt_factory.build_learning_rate())
def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
if not self._change_train_dataset and self._my_train_dataset:
return self._my_train_dataset
if self._strategy:
self._my_train_dataset = orbit.utils.make_distributed_dataset(
self._strategy,
self._build_inputs,
stage_id)
else:
self._my_train_dataset = self._build_inputs(stage_id)
return self._my_train_dataset
def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
if self._strategy:
return orbit.utils.make_distributed_dataset(
self._strategy,
self._build_inputs,
stage_id)
return self._build_inputs(stage_id)
def _build_inputs(self, stage_id):
def dummy_data(_):
batch_size = 2 if stage_id == 0 else 1
x = tf.zeros(shape=(batch_size, 2), dtype=tf.float32)
label = tf.zeros(shape=(batch_size, 1), dtype=tf.float32)
return x, label
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
return dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainerTest, self).setUp()
self._config = get_exp_config()
def create_test_trainer(self, distribution, model_dir, change_train_dataset):
trainer = trainer_lib.ProgressiveTrainer(
self._config,
prog_task=TestPolicy(
distribution, self._config.task, change_train_dataset),
ckpt_dir=model_dir)
return trainer
@combinations.generate(all_strategy_combinations())
def test_checkpointing(self, distribution):
model_dir = self.get_temp_dir()
ckpt_file = os.path.join(model_dir, 'ckpt')
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
self.assertFalse(trainer._task.is_last_stage)
trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
self.assertTrue(trainer._task.is_last_stage)
trainer.checkpoint.save(ckpt_file)
trainer = self.create_test_trainer(distribution, model_dir, True)
self.assertFalse(trainer._task.is_last_stage)
trainer.checkpoint.restore(ckpt_file + '-1')
self.assertTrue(trainer._task.is_last_stage)
@combinations.generate(all_strategy_combinations())
def test_train_dataset(self, distribution):
model_dir = self.get_temp_dir()
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
# Using dataset of stage == 0
train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
train_data = train_iter.next()[0]
if distribution.num_replicas_in_sync > 1:
train_data = train_data.values[0]
self.assertEqual(train_data.shape[0], 2)
trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
# Using dataset of stage == 1
train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
train_data = train_iter.next()[0]
if distribution.num_replicas_in_sync > 1:
train_data = train_data.values[0]
self.assertEqual(train_data.shape[0], 1)
with self.assertRaises(SyntaxError):
trainer.train_dataset = None
@combinations.generate(all_strategy_combinations())
def test_train_dataset_no_switch(self, distribution):
model_dir = self.get_temp_dir()
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, False)
trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
# _train_iter is not reset since the dataset is not changed.
self.assertIsNotNone(trainer._train_iter)
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
# _train_iter is reset since the dataset changed.
self.assertIsNone(trainer._train_iter)
class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainerWithMaskedLMTaskTest, self).setUp()
self._config = get_exp_config()
def create_test_trainer(self, distribution):
trainer = trainer_lib.ProgressiveTrainer(
self._config,
prog_task=TestPolicy(distribution, self._config.task),
ckpt_dir=self.get_temp_dir())
return trainer
@combinations.generate(all_strategy_combinations())
def test_trainer_train(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer(distribution)
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
@combinations.generate(all_strategy_combinations())
def test_trainer_validate(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer(distribution)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('validation_loss', logs)
self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
@combinations.generate(
combinations.combine(
mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
loss_scale=[None, 'dynamic', 128, 256],
))
def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
config = cfg.ExperimentConfig(
task=cfg.TaskConfig(
model=bert.PretrainerConfig()),
runtime=cfg.RuntimeConfig(
mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
trainer=trainer_lib.ProgressiveTrainerConfig(
export_checkpoint=True,
export_checkpoint_interval=1,
export_only_final_stage_ckpt=False))
task = TestPolicy(None, config.task)
trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
if mixed_precision_dtype != 'float16':
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
elif mixed_precision_dtype == 'float16' and loss_scale is None:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util classes and functions."""
from absl import logging
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.training.tracking import tracking
class VolatileTrackable(tracking.AutoTrackable):
"""A util class to keep Trackables that might change instances."""
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def reassign_trackable(self, **kwargs):
for k, v in kwargs.items():
delattr(self, k) # untrack this object
setattr(self, k, v) # track the new object
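For illustration, a minimal sketch (not part of this file; it assumes the module is importable as `official.modeling.progressive.utils`) of how `VolatileTrackable` lets a checkpoint keep tracking an attribute whose instance is swapped between stages:

```
import tensorflow as tf
from official.modeling.progressive import utils  # assumed module path

# Track an optimizer that will be swapped out at a stage boundary.
holder = utils.VolatileTrackable(optimizer=tf.keras.optimizers.SGD())
ckpt = tf.train.Checkpoint(holder=holder)

# A later stage replaces the instance; the holder untracks the old object
# and tracks the new one under the same attribute name.
holder.reassign_trackable(optimizer=tf.keras.optimizers.Adam())
```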
class CheckpointWithHooks(tf.train.Checkpoint):
"""Same as tf.train.Checkpoint but supports hooks.
In progressive training, use this class instead of tf.train.Checkpoint.
Since the network architecture changes during progressive training, we need to
prepare something (like switch to the correct architecture) before loading the
checkpoint. This class supports a hook that will be executed before checkpoint
loading.
"""
def __init__(self, before_load_hook, **kwargs):
self._before_load_hook = before_load_hook
super(CheckpointWithHooks, self).__init__(**kwargs)
# override
def read(self, save_path, options=None):
self._before_load_hook(save_path)
logging.info('Ran before_load_hook.')
super(CheckpointWithHooks, self).read(save_path=save_path, options=options)
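A minimal usage sketch (hypothetical hook and placeholder checkpoint path; same assumed module path as above): the hook runs before weights are read, so the caller can switch to the stage-appropriate architecture first:

```
import tensorflow as tf
from official.modeling.progressive import utils  # assumed module path

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])

def switch_architecture(save_path):
  # In progressive training this would rebuild/switch the model to match
  # the stage recorded in the checkpoint before any weights are restored.
  del save_path

ckpt = utils.CheckpointWithHooks(
    before_load_hook=switch_architecture, model=model)
ckpt.read("/tmp/ckpt-1")  # placeholder path; the hook fires first
```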
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common TF utilities."""
import six
import tensorflow as tf
@@ -29,8 +25,7 @@ from official.modeling import activations
None,
"tf.keras.layers.Layer supports multiple positional args and kwargs as "
"input tensors. pack/unpack inputs to override __call__ is no longer "
"needed."
)
"needed.")
def pack_inputs(inputs):
"""Pack a list of `inputs` tensors to a tuple.
@@ -55,8 +50,7 @@ def pack_inputs(inputs):
None,
"tf.keras.layers.Layer supports multiple positional args and kwargs as "
"input tensors. pack/unpack inputs to override __call__ is no longer "
"needed."
)
"needed.")
def unpack_inputs(inputs):
"""unpack a tuple of `inputs` tensors to a tuple.
@@ -88,27 +82,44 @@ def is_special_none_tensor(tensor):
return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
def get_activation(identifier, use_keras_layer=False):
"""Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
It checks string first and if it is one of customized activation not in TF,
the corresponding activation will be returned. For non-customized activation
names and callable identifiers, always fallback to tf.keras.activations.get.
Prefers using keras layers when use_keras_layer=True. Now it only supports
'relu', 'linear', 'identity', 'swish'.
Args:
identifier: String name of the activation function or callable.
use_keras_layer: If True, use keras layer if identifier is allow-listed.
Returns:
A Python function corresponding to the activation function or a keras
activation layer when use_keras_layer=True.
"""
if isinstance(identifier, six.string_types):
identifier = str(identifier).lower()
if use_keras_layer:
keras_layer_allowlist = {
"relu": "relu",
"linear": "linear",
"identity": "linear",
"swish": "swish",
"relu6": tf.nn.relu6,
}
if identifier in keras_layer_allowlist:
return tf.keras.layers.Activation(keras_layer_allowlist[identifier])
name_to_fn = {
"gelu": activations.gelu,
"simple_swish": activations.simple_swish,
"hard_swish": activations.hard_swish,
"relu6": activations.relu6,
"hard_sigmoid": activations.hard_sigmoid,
"identity": activations.identity,
}
if identifier in name_to_fn:
return tf.keras.activations.get(name_to_fn[identifier])
return tf.keras.activations.get(identifier)
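For example, a small sketch of the lookup behavior described above (assumes the Model Garden is installed so `official.modeling.tf_utils` is importable):

```
import tensorflow as tf
from official.modeling import tf_utils

gelu_fn = tf_utils.get_activation("gelu")           # customized activation fn
swish_fn = tf_utils.get_activation("simple_swish")  # customized activation fn
relu_layer = tf_utils.get_activation("relu", use_keras_layer=True)
assert isinstance(relu_layer, tf.keras.layers.Activation)
```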
......
@@ -20,8 +20,11 @@ to experiment new research ideas.
We provide a modeling library to allow users to train custom models for new
research ideas. Detailed instructions can be found in the READMEs in each
folder.
* [modeling/](modeling): modeling library that provides building blocks
    (e.g., Layers, Networks, and Models) that can be assembled into
    transformer-based architectures (see the sketch below).
* [data/](data): binaries and utils for input preprocessing, tokenization,
    etc.
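For instance, a hedged sketch of assembling an encoder from these building blocks (constructor arguments are illustrative; check the installed library version):

```
from official.nlp.modeling import networks

# A small BERT-style encoder assembled from the library's building blocks.
encoder = networks.BertEncoder(
    vocab_size=30522,
    num_layers=4,
    hidden_size=256,
    num_attention_heads=4)
```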
### State-of-the-Art models and examples
@@ -29,9 +32,31 @@ We provide SoTA model implementations, pre-trained models, training and
evaluation examples, and command lines. Detailed instructions can be found in
the READMEs for specific papers.
1. [BERT](bert): [BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding](https://arxiv.org/abs/1810.04805) by Devlin et al.,
2018
2. [ALBERT](albert):
[A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942)
by Lan et al., 2019
3. [XLNet](xlnet):
[XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237)
by Yang et al., 2019
4. [Transformer for translation](transformer):
[Attention Is All You Need](https://arxiv.org/abs/1706.03762) by Vaswani et
al., 2017
5. [NHNet](nhnet):
[Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386)
by Gu et al., 2020
### Common Training Driver
We provide a single common driver [train.py](train.py) to train the above SoTA
models on popular tasks. Please see [docs/train.md](docs/train.md) for
more details.
### Pre-trained models with checkpoints and TF-Hub
We provide a large collection of baselines and checkpoints for NLP pre-trained
models. Please see [docs/pretrained_models.md](docs/pretrained_models.md) for
more details.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -148,7 +148,7 @@ python ../data/create_finetuning_data.py \
--meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
--fine_tuning_task_type=classification --max_seq_length=128 \
--classification_task_name=${TASK_NAME} \
--tokenization=SentencePiece
```
* SQUAD
@@ -177,7 +177,7 @@ python ../data/create_finetuning_data.py \
--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
--fine_tuning_task_type=squad --max_seq_length=384 \
--tokenization=SentencePiece
```
## Fine-tuning with ALBERT
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The ALBERT configurations."""
import six
@@ -26,10 +22,7 @@ from official.nlp.bert import configs
class AlbertConfig(configs.BertConfig):
"""Configuration for `ALBERT`."""
def __init__(self, num_hidden_groups=1, inner_group_num=1, **kwargs):
"""Constructs AlbertConfig.
Args:
@@ -43,8 +36,7 @@ class AlbertConfig(configs.BertConfig):
super(AlbertConfig, self).__init__(**kwargs)
# TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
# in the released ALBERT. Support other values in AlbertEncoder if needed.
if inner_group_num != 1 or num_hidden_groups != 1:
raise ValueError("We only support 'inner_group_num' and "
"'num_hidden_groups' as 1.")
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to export the ALBERT core model as a TF-Hub SavedModel."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
from typing import Text
from official.nlp.albert import configs
from official.nlp.bert import bert_models
FLAGS = flags.FLAGS
flags.DEFINE_string("albert_config_file", None,
"Albert configuration file to define core albert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string(
"sp_model_file", None,
"The sentence piece model file that the ALBERT model was trained on.")
def create_albert_model(
albert_config: configs.AlbertConfig) -> tf.keras.Model:
"""Creates an ALBERT keras core model from ALBERT configuration.
Args:
albert_config: An `AlbertConfig` to create the core model.
Returns:
A keras core model and the transformer encoder inside it.
"""
# Adds input layers just as placeholders.
input_word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_mask")
input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_type_ids")
transformer_encoder = bert_models.get_transformer_encoder(
albert_config, sequence_length=None)
sequence_output, pooled_output = transformer_encoder(
[input_word_ids, input_mask, input_type_ids])
# To keep consistent with legacy hub modules, the outputs are
# "pooled_output" and "sequence_output".
return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[pooled_output, sequence_output]), transformer_encoder
def export_albert_tfhub(albert_config: configs.AlbertConfig,
model_checkpoint_path: Text, hub_destination: Text,
sp_model_file: Text):
"""Restores a tf.keras.Model and saves for TF-Hub."""
core_model, encoder = create_albert_model(albert_config)
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.restore(model_checkpoint_path).assert_consumed()
core_model.sp_model_file = tf.saved_model.Asset(sp_model_file)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
def main(_):
albert_config = configs.AlbertConfig.from_json_file(
FLAGS.albert_config_file)
export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.sp_model_file)
if __name__ == "__main__":
app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests official.nlp.albert.export_albert_tfhub."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from official.nlp.albert import configs
from official.nlp.albert import export_albert_tfhub
class ExportAlbertTfhubTest(tf.test.TestCase):
def test_export_albert_tfhub(self):
# Exports a savedmodel for TF-Hub
albert_config = configs.AlbertConfig(
vocab_size=100,
embedding_size=8,
hidden_size=16,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
num_hidden_layers=1)
bert_model, encoder = export_albert_tfhub.create_albert_model(albert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
sp_model_file = os.path.join(self.get_temp_dir(), "sp_tokenizer.model")
with tf.io.gfile.GFile(sp_model_file, "w") as f:
f.write("dummy content")
hub_destination = os.path.join(self.get_temp_dir(), "hub")
export_albert_tfhub.export_albert_tfhub(
albert_config,
model_checkpoint_path,
hub_destination,
sp_model_file=sp_model_file)
# Restores a hub KerasLayer.
hub_layer = hub.KerasLayer(hub_destination, trainable=True)
if hasattr(hub_layer, "resolved_object"):
with tf.io.gfile.GFile(
hub_layer.resolved_object.sp_model_file.asset_path.numpy()) as f:
self.assertEqual("dummy content", f.read())
# Checks the hub KerasLayer.
for source_weight, hub_weight in zip(bert_model.trainable_weights,
hub_layer.trainable_weights):
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
dummy_ids = np.zeros((2, 10), dtype=np.int32)
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
# The outputs of hub module are "pooled_output" and "sequence_output",
# while the outputs of the encoder are in reversed order, i.e.,
# "sequence_output" and "pooled_output".
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
self.assertEqual(hub_outputs[0].shape, (2, 16))
self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
for source_output, hub_output, encoder_output in zip(
source_outputs, hub_outputs, encoder_outputs):
self.assertAllClose(source_output.numpy(), hub_output.numpy())
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,25 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ALBERT classification finetuning runner in tf2.x."""
import json
import os
# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import bert_models
from official.nlp.bert import run_classifier as run_classifier_bert
FLAGS = flags.FLAGS
@@ -76,7 +71,7 @@ def main(_):
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,27 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
import json
import os
import time
# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib_sp
flags.DEFINE_string(
'sp_model_file', None,
@@ -103,9 +99,8 @@ def main(_):
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,25 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A converter from a tf1 ALBERT encoder checkpoint to a tf2 encoder checkpoint.
The conversion will yield an object-oriented checkpoint that can be used
to restore an AlbertEncoder object.
"""
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.modeling import activations
from official.modeling import tf_utils
from official.nlp.albert import configs
from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import models
from official.nlp.modeling import networks
FLAGS = flags.FLAGS
@@ -42,6 +39,14 @@ flags.DEFINE_string(
"BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None,
"Name for the created object-based V2 checkpoint.")
flags.DEFINE_string("checkpoint_model_name", "encoder",
"The name of the model when saving the checkpoint, i.e., "
"the checkpoint will be saved using: "
"tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
flags.DEFINE_enum(
"converted_model", "encoder", ["encoder", "pretrainer"],
"Whether to convert the checkpoint to a `AlbertEncoder` model or a "
"`BertPretrainerV2` model (with mlm but without classification heads).")
ALBERT_NAME_REPLACEMENTS = (
@@ -55,11 +60,12 @@ ALBERT_NAME_REPLACEMENTS = (
("group_0/inner_group_0/", ""),
("attention_1/self", "self_attention"),
("attention_1/output/dense", "self_attention/attention_output"),
("LayerNorm/", "self_attention_layer_norm/"),
("transformer/LayerNorm/", "transformer/self_attention_layer_norm/"),
("ffn_1/intermediate/dense", "intermediate"),
("ffn_1/intermediate/output/dense", "output"),
("LayerNorm_1/", "output_layer_norm/"),
("transformer/LayerNorm_1/", "transformer/output_layer_norm/"),
("pooler/dense", "pooler_transform"),
("cls/predictions", "bert/cls/predictions"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
("cls/seq_relationship/output_weights",
@@ -68,32 +74,54 @@ ALBERT_NAME_REPLACEMENTS = (
def _create_albert_model(cfg):
"""Creates a BERT keras core model from BERT configuration.
"""Creates an ALBERT keras core model from BERT configuration.
Args:
cfg: An `AlbertConfig` to create the core model.
Returns:
A keras model.
"""
albert_encoder = networks.AlbertEncoder(
vocab_size=cfg.vocab_size,
hidden_size=cfg.hidden_size,
embedding_width=cfg.embedding_size,
num_layers=cfg.num_hidden_layers,
num_attention_heads=cfg.num_attention_heads,
intermediate_size=cfg.intermediate_size,
activation=tf_utils.get_activation(cfg.hidden_act),
dropout_rate=cfg.hidden_dropout_prob,
attention_dropout_rate=cfg.attention_probs_dropout_prob,
max_sequence_length=cfg.max_position_embeddings,
type_vocab_size=cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range))
return albert_encoder
def _create_pretrainer_model(cfg):
"""Creates a pretrainer with AlbertEncoder from ALBERT configuration.
Args:
cfg: An `AlbertConfig` to create the core model.
Returns:
A BertPretrainerV2 model.
"""
albert_encoder = _create_albert_model(cfg)
pretrainer = models.BertPretrainerV2(
encoder_network=albert_encoder,
mlm_activation=tf_utils.get_activation(cfg.hidden_act),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range))
# Makes sure masked_lm layer's variables in pretrainer are created.
_ = pretrainer(pretrainer.inputs)
return pretrainer
def convert_checkpoint(bert_config, output_path, v1_checkpoint,
checkpoint_model_name,
converted_model="encoder"):
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
output_dir, _ = os.path.split(output_path)
@@ -109,9 +137,16 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
exclude_patterns=["adam", "Adam"])
# Create a V2 checkpoint from the temporary checkpoint.
if converted_model == "encoder":
model = _create_albert_model(bert_config)
elif converted_model == "pretrainer":
model = _create_pretrainer_model(bert_config)
else:
raise ValueError("Unsupported converted_model: %s" % converted_model)
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
output_path,
checkpoint_model_name)
# Clean up the temporary checkpoint, if it exists.
try:
@@ -124,8 +159,12 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
def main(_):
output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert
checkpoint_model_name = FLAGS.checkpoint_model_name
converted_model = FLAGS.converted_model
albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
convert_checkpoint(albert_config, output_path, v1_checkpoint,
checkpoint_model_name,
converted_model=converted_model)
if __name__ == "__main__":
......
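For reference, a hedged sketch of calling `convert_checkpoint` directly from Python (paths are placeholders; it assumes the converter module above is on the path so the function is importable):

```
from official.nlp.albert import configs

albert_config = configs.AlbertConfig.from_json_file("albert_config.json")
convert_checkpoint(
    albert_config,
    output_path="converted/albert_ckpt",
    v1_checkpoint="v1/model.ckpt",
    checkpoint_model_name="encoder",
    converted_model="pretrainer")
```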
# BERT (Bidirectional Encoder Representations from Transformers)
**WARNING**: We are in the process of deprecating most of the code in this
directory. Please see
[this link](https://github.com/tensorflow/models/blob/master/official/nlp/docs/train.md)
for the new tutorial and use the new code in `nlp/modeling`. This README is
still correct for this legacy implementation.
The academic paper which describes BERT in detail and provides full results on a
number of tasks can be found here: https://arxiv.org/abs/1810.04805.
@@ -46,6 +52,8 @@ The new checkpoints are:**
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Multilingual Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/multi_cased_L-12_H-768_A-12.tar.gz)**:
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
We recommend hosting checkpoints in Google Cloud Storage buckets when you use
Cloud GPU/TPU.
@@ -70,21 +78,21 @@ Checkpoints featuring native serialized Keras models
Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
following links:
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Large, Cased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Multilingual Cased`](https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/)**:
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Base, Chinese`](https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/)**:
Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads,
110M parameters
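For reference, a minimal sketch of loading one of the modules above as a Keras layer (handle copied from the list; `tensorflow_hub` must be installed):

```
import tensorflow_hub as hub

bert_layer = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/",
    trainable=True)
```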
@@ -123,6 +131,23 @@ which is essentially branched from [BERT research repo](https://github.com/googl
to get processed pre-training data; it adapts to TF2 symbols and Python 3
compatibility.
Running the pre-training script requires an input and output directory, as well as a vocab file. Note that `max_seq_length` will need to match the sequence length parameter you specify when you run pre-training.
Example shell script to call create_pretraining_data.py:
```
export WORKING_DIR='local disk or cloud location'
export BERT_DIR='local disk or cloud location'
python models/official/nlp/data/create_pretraining_data.py \
--input_file=$WORKING_DIR/input/input.txt \
--output_file=$WORKING_DIR/output/tf_examples.tfrecord \
--vocab_file=$BERT_DIR/wwm_uncased_L-24_H-1024_A-16/vocab.txt \
--do_lower_case=True \
--max_seq_length=512 \
--max_predictions_per_seq=76 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
```
### Fine-tuning
......@@ -184,6 +209,8 @@ python ../data/create_finetuning_data.py \
--fine_tuning_task_type=squad --max_seq_length=384
```
Note: To create fine-tuning data with SQuAD 2.0, you need to add the flag `--version_2_with_negative=True`.
## Fine-tuning with BERT
### Cloud GPUs and TPUs
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -4,17 +4,17 @@ This tutorial shows you how to train the Bidirectional Encoder Representations f
## Set up Cloud Storage and Compute Engine VM
1. [Open a cloud shell window](https://console.cloud.google.com/?cloudshell=true&_ga=2.11844148.-1612541229.1552429951)
2. Create a variable for the project's id:
```
export PROJECT_ID=your-project_id
```
3. Configure `gcloud` command-line tool to use the project where you want to create Cloud TPU.
```
gcloud config set project ${PROJECT_ID}
```
4. Create a Cloud Storage bucket using the following command:
```
gsutil mb -p ${PROJECT_ID} -c standard -l europe-west4 -b on gs://your-bucket-name
```
This Cloud Storage bucket stores the data you use to train your model and the training results.
5. Launch a Compute Engine VM and Cloud TPU using the ctpu up command.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT models that are compatible with TF 2.0."""
import gin
import tensorflow as tf
@@ -104,29 +100,29 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
@gin.configurable
def get_transformer_encoder(bert_config,
sequence_length=None,
transformer_encoder_cls=None,
output_range=None):
"""Gets a 'TransformerEncoder' object.
Args:
bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
sequence_length: [Deprecated].
transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
default BERT encoder implementation.
output_range: the sequence output range, [0, output_range). Default setting
is to return the entire sequence output.
Returns:
An encoder object.
"""
del sequence_length
if transformer_encoder_cls is not None:
# TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
embedding_cfg = dict(
vocab_size=bert_config.vocab_size,
type_vocab_size=bert_config.type_vocab_size,
hidden_size=bert_config.hidden_size,
max_seq_length=bert_config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
@@ -161,18 +157,17 @@ def get_transformer_encoder(bert_config,
activation=tf_utils.get_activation(bert_config.hidden_act),
dropout_rate=bert_config.hidden_dropout_prob,
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
max_sequence_length=bert_config.max_position_embeddings,
type_vocab_size=bert_config.type_vocab_size,
embedding_width=bert_config.embedding_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range))
if isinstance(bert_config, albert_configs.AlbertConfig):
return networks.AlbertEncoder(**kwargs)
else:
assert isinstance(bert_config, configs.BertConfig)
kwargs['output_range'] = output_range
return networks.BertEncoder(**kwargs)
def pretrain_model(bert_config,
......
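A small usage sketch of the updated signature (config values are illustrative; `sequence_length` no longer needs to be passed):

```
from official.nlp.bert import bert_models
from official.nlp.bert import configs

bert_config = configs.BertConfig(
    vocab_size=100,
    hidden_size=16,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=32)
encoder = bert_models.get_transformer_encoder(bert_config)  # a BertEncoder
```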
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,10 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
@@ -48,16 +44,16 @@ class BertModelsTest(tf.test.TestCase):
initializer=None,
use_next_sentence_label=True)
self.assertIsInstance(model, tf.keras.Model)
self.assertIsInstance(encoder, networks.BertEncoder)
# model has one scalar output: loss value.
self.assertEqual(model.output.shape.as_list(), [
None,
])
# Expect two outputs from encoder: sequence and classification output.
self.assertIsInstance(encoder.output, list)
self.assertLen(encoder.output, 2)
# shape should be [batch size, seq_length, hidden_size]
self.assertEqual(encoder.output[0].shape.as_list(), [None, 5, 16])
# shape should be [batch size, hidden_size]
self.assertEqual(encoder.output[1].shape.as_list(), [None, 16])
@@ -74,16 +70,12 @@ class BertModelsTest(tf.test.TestCase):
# Expect two outputs from model: start positions and end positions
self.assertIsInstance(model.output, list)
self.assertLen(model.output, 2)
# shape should be [batch size, seq_length]
self.assertEqual(model.output[0].shape.as_list(), [None, 5])
# shape should be [batch size, seq_length]
self.assertEqual(model.output[1].shape.as_list(), [None, 5])
# Expect two outputs from core_model: sequence and classification output.
self.assertIsInstance(core_model.output, list)
self.assertLen(core_model.output, 2)
# shape should be [batch size, None, hidden_size]
self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
# shape should be [batch size, hidden_size]
self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
@@ -104,8 +96,8 @@ class BertModelsTest(tf.test.TestCase):
# Expect two outputs from core_model: sequence and classification output.
self.assertIsInstance(core_model.output, list)
self.assertLen(core_model.output, 2)
# shape should be [batch size, None, hidden_size]
self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
# shape should be [batch size, hidden_size]
self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defining common flags used across all BERT models/applications."""
from absl import flags
@@ -73,9 +73,22 @@ def define_common_bert_flags():
'If specified, init_checkpoint flag should not be used.')
flags.DEFINE_bool('hub_module_trainable', True,
'True to make keras layers in the hub module trainable.')
flags.DEFINE_string(
'sub_model_export_name', None,
'If set, `sub_model` checkpoints are exported into '
'FLAGS.model_dir/FLAGS.sub_model_export_name.')
flags.DEFINE_bool('explicit_allreduce', False,
'True to use explicit allreduce instead of the implicit '
'allreduce in optimizer.apply_gradients(). If fp16 mixed '
'precision training is used, this also enables allreduce '
'gradients in fp16.')
flags.DEFINE_integer('allreduce_bytes_per_pack', 0,
'Number of bytes of a gradient pack for allreduce. '
'Should be positive integer, if set to 0, all '
'gradients are in one pack. Breaking gradient into '
'packs could enable overlap between allreduce and '
'backprop computation. This flag only takes effect '
'when explicit_allreduce is set to True.')
flags_core.define_log_steps()
@@ -87,7 +100,6 @@ def define_common_bert_flags():
synthetic_data=False,
max_train_steps=False,
dtype=True,
loss_scale=True,
all_reduce_alg=True,
num_packs=False,
......
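A hedged sketch of parsing the new flags programmatically (assumes the module is importable as `official.nlp.bert.common_flags` and that flags have not already been parsed elsewhere):

```
from absl import flags
from official.nlp.bert import common_flags

common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
FLAGS([
    "prog",
    "--explicit_allreduce=true",
    "--allreduce_bytes_per_pack=4194304",  # ~4 MB per gradient pack
])
assert FLAGS.explicit_allreduce
```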