Commit 2565629c authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 299169021
parent d3d7f15f
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import gin
 import tensorflow as tf
 import tensorflow_hub as hub
@@ -85,16 +86,46 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss
 
 
-def get_transformer_encoder(bert_config, sequence_length):
+@gin.configurable
+def get_transformer_encoder(bert_config,
+                            sequence_length,
+                            transformer_encoder_cls=None):
   """Gets a 'TransformerEncoder' object.
 
   Args:
     bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
     sequence_length: Maximum sequence length of the training data.
+    transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
+      default BERT encoder implementation.
 
   Returns:
     A networks.TransformerEncoder object.
   """
+  if transformer_encoder_cls is not None:
+    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
+    embedding_cfg = dict(
+        vocab_size=bert_config.vocab_size,
+        type_vocab_size=bert_config.type_vocab_size,
+        hidden_size=bert_config.hidden_size,
+        seq_length=sequence_length,
+        max_seq_length=bert_config.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=bert_config.initializer_range),
+        dropout_rate=bert_config.hidden_dropout_prob,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=bert_config.num_attention_heads,
+        intermediate_size=bert_config.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
+        dropout_rate=bert_config.hidden_dropout_prob,
+        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+    )
+    kwargs = dict(embedding_cfg=embedding_cfg, hidden_cfg=hidden_cfg,
+                  num_hidden_instances=bert_config.num_hidden_layers,)
+    # Relies on gin configuration to define the Transformer encoder arguments.
+    return transformer_encoder_cls(**kwargs)
+
   kwargs = dict(
       vocab_size=bert_config.vocab_size,
       hidden_size=bert_config.hidden_size,
......
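As context for the hunk above, here is a minimal, illustrative sketch of how the new transformer_encoder_cls hook might be exercised through gin. It assumes the default gin names created by the decorators in this change and is not part of the commit; the binding string and the use of EncoderScaffold are assumptions.

import gin

# The module that defines the @gin.configurable EncoderScaffold must already
# be imported so the class is registered with gin before this call.
gin.parse_config_files_and_bindings(
    config_files=None,
    bindings=[
        'get_transformer_encoder.transformer_encoder_cls = @EncoderScaffold'
    ])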
@@ -20,6 +20,14 @@ import tensorflow as tf
 from official.utils.flags import core as flags_core
 
 
+def define_gin_flags():
+  """Define common gin configurable flags."""
+  flags.DEFINE_multi_string('gin_file', None,
+                            'List of paths to the config files.')
+  flags.DEFINE_multi_string(
+      'gin_param', None, 'Newline separated list of Gin parameter bindings.')
+
+
 def define_common_bert_flags():
   """Define common flags for BERT tasks."""
   flags_core.define_base(
......
@@ -20,6 +20,7 @@ from __future__ import print_function
 from absl import app
 from absl import flags
 from absl import logging
+import gin
 import tensorflow as tf
 
 from official.modeling import model_training_utils
@@ -49,6 +50,7 @@ flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')
 
 common_flags.define_common_bert_flags()
+common_flags.define_gin_flags()
 
 FLAGS = flags.FLAGS
@@ -158,7 +160,7 @@ def run_bert_pretrain(strategy):
 def main(_):
   # Users should always run this script under TF 2.x
   assert tf.version.VERSION.startswith('2.')
+  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
   if not FLAGS.model_dir:
     FLAGS.model_dir = '/tmp/bert20/'
   strategy = distribution_utils.get_distribution_strategy(
......
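A small self-contained sketch of the wiring this hunk adds to the pretraining script: the gin flags are defined, then --gin_file and --gin_param are applied before any model is built. The standalone script structure below is illustrative only; flag names match the diff, everything else is an assumption.

from absl import app
from absl import flags
import gin

flags.DEFINE_multi_string('gin_file', None,
                          'List of paths to the config files.')
flags.DEFINE_multi_string(
    'gin_param', None, 'Newline separated list of Gin parameter bindings.')

FLAGS = flags.FLAGS


def main(_):
  # Apply gin configs and bindings up front so later @gin.configurable calls
  # (e.g. get_transformer_encoder) pick up the overrides.
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)


if __name__ == '__main__':
  app.run(main)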
@@ -19,6 +19,7 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function
 
+import gin
 import tensorflow as tf
 
 from official.nlp.modeling.layers import attention
@@ -26,6 +27,7 @@ from official.nlp.modeling.layers import dense_einsum
 
 @tf.keras.utils.register_keras_serializable(package="Text")
+@gin.configurable
 class TransformerScaffold(tf.keras.layers.Layer):
   """Transformer scaffold layer.
......
+# Lint as: python3
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,6 +21,8 @@ from __future__ import division
 from __future__ import print_function
 
 import inspect
+
+import gin
 import tensorflow as tf
 
 from tensorflow.python.keras.engine import network  # pylint: disable=g-direct-tensorflow-import
@@ -27,6 +30,7 @@ from official.nlp.modeling import layers
 
 @tf.keras.utils.register_keras_serializable(package='Text')
+@gin.configurable
 class EncoderScaffold(network.Network):
   """Bi-directional Transformer-based encoder network scaffold.
@@ -96,7 +100,6 @@ class EncoderScaffold(network.Network):
                hidden_cls=layers.Transformer,
                hidden_cfg=None,
                **kwargs):
-    print(embedding_cfg)
     self._self_setattr_tracking = False
     self._hidden_cls = hidden_cls
     self._hidden_cfg = hidden_cfg
@@ -171,7 +174,8 @@ class EncoderScaffold(network.Network):
     for _ in range(num_hidden_instances):
       if inspect.isclass(hidden_cls):
-        layer = self._hidden_cls(**hidden_cfg)
+        layer = self._hidden_cls(
+            **hidden_cfg) if hidden_cfg else self._hidden_cls()
       else:
         layer = self._hidden_cls
       data = layer([data, attention_mask])
......
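The last hunk guards the hidden_cfg splat so a hidden_cls whose constructor takes no arguments also works when hidden_cfg is None. A standalone illustration of that behavior; the class and helper names are made up for the sketch and are not part of the commit.

import inspect


class NoArgBlock:
  """Stand-in hidden layer class whose constructor takes no arguments."""


def build_hidden_layer(hidden_cls, hidden_cfg=None):
  # Mirrors the patched EncoderScaffold logic: only unpack hidden_cfg when it
  # is provided; otherwise call the class with no arguments.
  if inspect.isclass(hidden_cls):
    return hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
  return hidden_cls  # already an instance


layer = build_hidden_layer(NoArgBlock)  # the unguarded splat would have raised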