Commit 2565629c authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 299169021
parent d3d7f15f
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gin
import tensorflow as tf
import tensorflow_hub as hub
@@ -85,16 +86,46 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
    return final_loss
def get_transformer_encoder(bert_config, sequence_length):
@gin.configurable
def get_transformer_encoder(bert_config,
                            sequence_length,
                            transformer_encoder_cls=None):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: Maximum sequence length of the training data.
    transformer_encoder_cls: An EncoderScaffold class. If it is None, uses the
      default BERT encoder implementation.

  Returns:
    A networks.TransformerEncoder object.
  """
  if transformer_encoder_cls is not None:
    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
    embedding_cfg = dict(
        vocab_size=bert_config.vocab_size,
        type_vocab_size=bert_config.type_vocab_size,
        hidden_size=bert_config.hidden_size,
        seq_length=sequence_length,
        max_seq_length=bert_config.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
        dropout_rate=bert_config.hidden_dropout_prob,
    )
    hidden_cfg = dict(
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
    )
    kwargs = dict(embedding_cfg=embedding_cfg, hidden_cfg=hidden_cfg,
                  num_hidden_instances=bert_config.num_hidden_layers,)
    # Relies on gin configuration to define the Transformer encoder arguments.
    return transformer_encoder_cls(**kwargs)

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
......
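With get_transformer_encoder now gin-configurable, the encoder class can be swapped in from a gin binding instead of from code. A minimal sketch of such a binding follows; the import path used to register the scaffold and the idea of binding extra EncoderScaffold arguments the same way are assumptions for illustration, not part of this commit.

import gin

# Importing the module registers the gin-configurable EncoderScaffold class
# (module path assumed; it is not shown in this diff).
from official.nlp.modeling.networks import encoder_scaffold  # pylint: disable=unused-import

gin.parse_config("""
# Route encoder construction through EncoderScaffold rather than the default
# BERT encoder. Any EncoderScaffold constructor arguments beyond
# embedding_cfg, hidden_cfg and num_hidden_instances would be bound here too.
get_transformer_encoder.transformer_encoder_cls = @EncoderScaffold
""")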
@@ -20,6 +20,14 @@ import tensorflow as tf
from official.utils.flags import core as flags_core
def define_gin_flags():
  """Define common gin configurable flags."""
  flags.DEFINE_multi_string('gin_file', None,
                            'List of paths to the config files.')
  flags.DEFINE_multi_string(
      'gin_param', None, 'Newline separated list of Gin parameter bindings.')


def define_common_bert_flags():
  """Define common flags for BERT tasks."""
  flags_core.define_base(
......
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.modeling import model_training_utils
@@ -49,6 +50,7 @@ flags.DEFINE_float('warmup_steps', 10000,
                   'Warmup steps for Adam weight decay optimizer.')
common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS
@@ -158,7 +160,7 @@ def run_bert_pretrain(strategy):
def main(_):
  # Users should always run this script under TF 2.x
  assert tf.version.VERSION.startswith('2.')
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'
  strategy = distribution_utils.get_distribution_strategy(
......
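The new define_gin_flags plus the gin.parse_config_files_and_bindings call in main is all the wiring the pretraining script needs to accept gin overrides from the command line. Below is a small, self-contained sketch of that same pattern; only the flag names (--gin_file, --gin_param) and the parse call come from this diff, while the configurable train function is a made-up placeholder.

from absl import app
from absl import flags
import gin

flags.DEFINE_multi_string('gin_file', None,
                          'List of paths to the config files.')
flags.DEFINE_multi_string(
    'gin_param', None, 'Newline separated list of Gin parameter bindings.')

FLAGS = flags.FLAGS


@gin.configurable
def train(learning_rate=1e-4):
  # Hypothetical configurable, only to show an override landing somewhere.
  print('learning_rate =', learning_rate)


def main(_):
  # Config files are applied first, then inline bindings, mirroring the
  # call added to the pretraining script above.
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train()


if __name__ == '__main__':
  app.run(main)

Run as, for example, python sketch.py --gin_param='train.learning_rate = 1e-3' to override the default without editing any code.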
@@ -19,6 +19,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
@@ -26,6 +27,7 @@ from official.nlp.modeling.layers import dense_einsum
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TransformerScaffold(tf.keras.layers.Layer):
"""Transformer scaffold layer.
......
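The only change to TransformerScaffold here is the added gin import and the @gin.configurable decorator. What the decorator buys is the ability to override constructor arguments from a gin binding without touching call sites. The layer below is a made-up stand-in used purely to illustrate that pattern; the real TransformerScaffold constructor is not shown in this diff.

import gin
import tensorflow as tf


@gin.configurable
class MyScaffold(tf.keras.layers.Layer):
  # Hypothetical layer; only the decorator behavior is the point here.

  def __init__(self, units=64, dropout_rate=0.0, **kwargs):
    super(MyScaffold, self).__init__(**kwargs)
    self._dense = tf.keras.layers.Dense(units)
    self._dropout = tf.keras.layers.Dropout(dropout_rate)

  def call(self, inputs, training=False):
    return self._dropout(self._dense(inputs), training=training)


# Call sites stay unchanged; the override comes entirely from the binding.
gin.parse_config('MyScaffold.dropout_rate = 0.1')
layer = MyScaffold()  # constructed with dropout_rate=0.1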
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,6 +21,8 @@ from __future__ import division
from __future__ import print_function
import inspect
import gin
import tensorflow as tf
from tensorflow.python.keras.engine import network # pylint: disable=g-direct-tensorflow-import
@@ -27,6 +30,7 @@ from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
@gin.configurable
class EncoderScaffold(network.Network):
"""Bi-directional Transformer-based encoder network scaffold.
@@ -96,7 +100,6 @@ class EncoderScaffold(network.Network):
               hidden_cls=layers.Transformer,
               hidden_cfg=None,
               **kwargs):
    print(embedding_cfg)
    self._self_setattr_tracking = False
    self._hidden_cls = hidden_cls
    self._hidden_cfg = hidden_cfg
@@ -171,7 +174,8 @@
    for _ in range(num_hidden_instances):
      if inspect.isclass(hidden_cls):
        layer = self._hidden_cls(**hidden_cfg)
        layer = self._hidden_cls(
            **hidden_cfg) if hidden_cfg else self._hidden_cls()
      else:
        layer = self._hidden_cls
      data = layer([data, attention_mask])
......
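The last hunk changes how EncoderScaffold instantiates its hidden layers: when hidden_cls is a class and no hidden_cfg is given, the class is now called with no arguments, which lets a gin-configured hidden_cls supply its own constructor arguments. A standalone restatement of that branch (the helper name is hypothetical, not part of the commit):

import inspect


def build_hidden_layer(hidden_cls, hidden_cfg=None):
  # Mirrors the updated branch above; written as a free function for clarity.
  if inspect.isclass(hidden_cls):
    # New behavior: fall back to a no-argument call when hidden_cfg is empty,
    # so gin bindings or constructor defaults fill in the layer's arguments.
    return hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
  # hidden_cls is already a layer instance; reuse it as-is.
  return hidden_cls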