Commit 102f267e authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 374236491
parent 34731381
@@ -18,15 +18,15 @@ Includes configurations and factory methods.
"""
from typing import Optional
from absl import logging
import dataclasses
import gin
import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.projects.bigbird import encoder as bigbird_encoder
from official.nlp.projects.bigbird import attention as bigbird_attention
@dataclasses.dataclass
@@ -177,15 +177,6 @@ class EncoderConfig(hyperparams.OneOfConfig):
xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
ENCODER_CLS = {
"bert": networks.BertEncoder,
"mobilebert": networks.MobileBERTEncoder,
"albert": networks.AlbertEncoder,
"bigbird": bigbird_encoder.BigBirdEncoder,
"xlnet": networks.XLNetBase,
}
@gin.configurable
def build_encoder(config: EncoderConfig,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
@@ -205,13 +196,11 @@ def build_encoder(config: EncoderConfig,
Returns:
An encoder instance.
"""
encoder_type = config.type
encoder_cfg = config.get()
encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
logging.info("Encoder class: %s to build...", encoder_cls.__name__)
if bypass_config:
return encoder_cls()
if encoder_cls.__name__ == "EncoderScaffold":
encoder_type = config.type
encoder_cfg = config.get()
if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
@@ -243,7 +232,7 @@ def build_encoder(config: EncoderConfig,
return encoder_cls(**kwargs)
if encoder_type == "mobilebert":
return encoder_cls(
return networks.MobileBERTEncoder(
word_vocab_size=encoder_cfg.word_vocab_size,
word_embed_size=encoder_cfg.word_embed_size,
type_vocab_size=encoder_cfg.type_vocab_size,
@@ -265,7 +254,7 @@ def build_encoder(config: EncoderConfig,
input_mask_dtype=encoder_cfg.input_mask_dtype)
if encoder_type == "albert":
return encoder_cls(
return networks.AlbertEncoder(
vocab_size=encoder_cfg.vocab_size,
embedding_width=encoder_cfg.embedding_width,
hidden_size=encoder_cfg.hidden_size,
@@ -282,26 +271,55 @@ def build_encoder(config: EncoderConfig,
dict_outputs=True)
if encoder_type == "bigbird":
return encoder_cls(
# TODO(frederickliu): Support use_gradient_checkpointing.
if encoder_cfg.use_gradient_checkpointing:
raise ValueError("Gradient checkpointing unsupported at the moment.")
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate)
attention_cfg = dict(
num_heads=encoder_cfg.num_attention_heads,
key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads),
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
max_rand_mask_length=encoder_cfg.max_position_embeddings,
num_rand_blocks=encoder_cfg.num_rand_blocks,
from_block_size=encoder_cfg.block_size,
to_block_size=encoder_cfg.block_size,
)
hidden_cfg = dict(
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
intermediate_activation=tf_utils.get_activation(
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
num_rand_blocks=encoder_cfg.num_rand_blocks,
block_size=encoder_cfg.block_size,
max_position_embeddings=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
embedding_width=encoder_cfg.embedding_width,
use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
attention_cls=bigbird_attention.BigBirdAttention,
attention_cfg=attention_cfg)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cls=layers.TransformerScaffold,
hidden_cfg=hidden_cfg,
num_hidden_instances=encoder_cfg.num_layers,
mask_cls=bigbird_attention.BigBirdMasks,
mask_cfg=dict(block_size=encoder_cfg.block_size),
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=False,
dict_outputs=True,
layer_idx_as_attention_seed=True)
return networks.EncoderScaffold(**kwargs)
if encoder_type == "xlnet":
return encoder_cls(
return networks.XLNetBase(
vocab_size=encoder_cfg.vocab_size,
num_layers=encoder_cfg.num_layers,
hidden_size=encoder_cfg.hidden_size,
@@ -325,7 +343,7 @@ def build_encoder(config: EncoderConfig,
# Uses the default BERTEncoder configuration schema to create the encoder.
# If it does not match, please add a switch branch by the encoder type.
return encoder_cls(
return networks.BertEncoder(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
......
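For orientation, here is a minimal usage sketch of the refactored builder (not part of this commit; the module path official.nlp.configs.encoders, the keyword construction of EncoderConfig, and the default field values are assumptions). After this change, selecting the bigbird branch no longer instantiates bigbird_encoder.BigBirdEncoder directly; build_encoder assembles a networks.EncoderScaffold around bigbird_attention.BigBirdAttention and BigBirdMasks instead.

from official.nlp.configs import encoders  # module path assumed for this sketch

# Select the BigBird branch of the OneOfConfig; all other fields keep the
# dataclass defaults declared in this file.
config = encoders.EncoderConfig(type="bigbird")

# With encoder_cls left at its default, the bigbird branch above returns
# networks.EncoderScaffold(**kwargs), built from TransformerScaffold layers
# with BigBirdAttention and a BigBirdMasks mask layer.
encoder = encoders.build_encoder(config)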
@@ -14,6 +14,7 @@
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
import copy
import inspect
from absl import logging
@@ -86,12 +87,19 @@ class EncoderScaffold(tf.keras.Model):
`dropout_rate`: The overall dropout rate for the transformer layers.
`attention_dropout_rate`: The dropout rate for the attention layers.
`kernel_initializer`: The initializer for the transformer layers.
mask_cls: The class used to generate masks; it is called on the embeddings and
the 2D mask indicating which positions can be attended to, and its output is
passed into hidden_cls(). It is the caller's responsibility to make sure the
output of mask_cls can be consumed by hidden_cls; a mask_cls is usually paired
with a particular hidden_cls.
mask_cfg: A dict of kwargs to pass to mask_cls.
layer_norm_before_pooling: Whether to add a layer norm before the pooling
layer. You probably want to turn this on if you set `norm_first=True` in
transformer layers.
return_all_layer_outputs: Whether to output sequence embedding outputs of
all encoder transformer layers.
dict_outputs: Whether to use a dictionary as the model outputs.
layer_idx_as_attention_seed: Whether to set the layer index as the attention
seed in the `attention_cfg` within `hidden_cfg`.
"""
def __init__(self,
@@ -104,9 +112,12 @@ class EncoderScaffold(tf.keras.Model):
num_hidden_instances=1,
hidden_cls=layers.Transformer,
hidden_cfg=None,
mask_cls=keras_nlp.layers.SelfAttentionMask,
mask_cfg=None,
layer_norm_before_pooling=False,
return_all_layer_outputs=False,
dict_outputs=False,
layer_idx_as_attention_seed=False,
**kwargs):
if embedding_cls:
@@ -169,15 +180,25 @@ class EncoderScaffold(tf.keras.Model):
tf.keras.layers.Dropout(
rate=embedding_cfg['dropout_rate'])(embeddings))
attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)
mask_cfg = {} if mask_cfg is None else mask_cfg
if inspect.isclass(mask_cls):
mask_layer = mask_cls(**mask_cfg)
else:
mask_layer = mask_cls
attention_mask = mask_layer(embeddings, mask)
data = embeddings
layer_output_data = []
hidden_layers = []
for _ in range(num_hidden_instances):
hidden_cfg = hidden_cfg if hidden_cfg else {}
for i in range(num_hidden_instances):
if inspect.isclass(hidden_cls):
layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
if hidden_cfg and 'attention_cfg' in hidden_cfg and (
layer_idx_as_attention_seed):
hidden_cfg = copy.deepcopy(hidden_cfg)
hidden_cfg['attention_cfg']['seed'] = i
layer = hidden_cls(**hidden_cfg)
else:
layer = hidden_cls
data = layer([data, attention_mask])
@@ -227,6 +248,8 @@ class EncoderScaffold(tf.keras.Model):
self._hidden_cls = hidden_cls
self._hidden_cfg = hidden_cfg
self._mask_cls = mask_cls
self._mask_cfg = mask_cfg
self._num_hidden_instances = num_hidden_instances
self._pooled_output_dim = pooled_output_dim
self._pooler_layer_initializer = pooler_layer_initializer
@@ -247,6 +270,7 @@ class EncoderScaffold(tf.keras.Model):
if self._layer_norm_before_pooling:
self._output_layer_norm = output_layer_norm
self._pooler_layer = pooler_layer
self._layer_idx_as_attention_seed = layer_idx_as_attention_seed
logging.info('EncoderScaffold configs: %s', self.get_config())
@@ -260,32 +284,48 @@ class EncoderScaffold(tf.keras.Model):
'layer_norm_before_pooling': self._layer_norm_before_pooling,
'return_all_layer_outputs': self._return_all_layer_outputs,
'dict_outputs': self._dict_outputs,
'layer_idx_as_attention_seed': self._layer_idx_as_attention_seed
}
cfgs = {
'hidden_cfg': self._hidden_cfg,
'mask_cfg': self._mask_cfg
}
if self._hidden_cfg:
config_dict['hidden_cfg'] = {}
for k, v in self._hidden_cfg.items():
for cfg_name, cfg in cfgs.items():
if cfg:
config_dict[cfg_name] = {}
for k, v in cfg.items():
# `self._hidden_cfg` may contain `class`, e.g., when `hidden_cfg` is
# `TransformerScaffold`, its `attention_cls` argument can be a `class`.
# `TransformerScaffold`, `attention_cls` argument can be a `class`.
if inspect.isclass(v):
config_dict['hidden_cfg'][k] = tf.keras.utils.get_registered_name(v)
config_dict[cfg_name][k] = tf.keras.utils.get_registered_name(v)
else:
config_dict['hidden_cfg'][k] = v
config_dict[cfg_name][k] = v
clss = {
'hidden_cls': self._hidden_cls,
'mask_cls': self._mask_cls
}
if inspect.isclass(self._hidden_cls):
config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
self._hidden_cls)
for cls_name, cls in clss.items():
if inspect.isclass(cls):
key = '{}_string'.format(cls_name)
config_dict[key] = tf.keras.utils.get_registered_name(cls)
else:
config_dict['hidden_cls'] = self._hidden_cls
config_dict[cls_name] = cls
config_dict.update(self._kwargs)
return config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
if 'hidden_cls_string' in config:
config['hidden_cls'] = tf.keras.utils.get_registered_object(
config['hidden_cls_string'], custom_objects=custom_objects)
del config['hidden_cls_string']
cls_names = ['hidden_cls', 'mask_cls']
for cls_name in cls_names:
cls_string = '{}_string'.format(cls_name)
if cls_string in config:
config[cls_name] = tf.keras.utils.get_registered_object(
config[cls_string], custom_objects=custom_objects)
del config[cls_string]
return cls(**config)
def get_embedding_table(self):
......
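The new mask_cls/mask_cfg arguments mirror the existing hidden_cls/hidden_cfg hooks: the mask layer is called with the embeddings and the 2D padding mask, and whatever it returns is fed to every hidden layer, so the two classes must be chosen as a matching pair (as build_encoder above does with BigBirdMasks and BigBirdAttention). A rough construction sketch follows; it is not taken from this commit, the hyperparameter values are placeholders, and the embedding_cfg keys are assumed to match those used by build_encoder above. Passing the previous hard-coded default, keras_nlp.layers.SelfAttentionMask, keeps the old behavior.

import tensorflow as tf
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.modeling.networks import encoder_scaffold

init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
network = encoder_scaffold.EncoderScaffold(
    num_hidden_instances=2,
    pooled_output_dim=64,
    pooler_layer_initializer=init,
    embedding_cfg=dict(
        vocab_size=100,
        type_vocab_size=2,
        hidden_size=64,
        max_seq_length=32,
        initializer=init,
        dropout_rate=0.1),
    hidden_cls=layers.Transformer,
    hidden_cfg=dict(
        num_attention_heads=2,
        intermediate_size=128,
        intermediate_activation="relu",
        dropout_rate=0.1,
        attention_dropout_rate=0.1,
        kernel_initializer=init),
    # New in this change: the mask layer is pluggable rather than hard-coded.
    mask_cls=keras_nlp.layers.SelfAttentionMask,
    mask_cfg=None,
    dict_outputs=True)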
@@ -20,6 +20,7 @@ import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.modeling.networks import encoder_scaffold
@@ -47,6 +48,30 @@ class ValidatedTransformerLayer(layers.Transformer):
return config
# Test class that wraps a standard self attention mask layer. If this layer
# is called at any point, the list passed to the config object will be filled
# with a boolean 'True'. We register this class as a Keras serializable so we
# can test serialization below.
@tf.keras.utils.register_keras_serializable(package="TestOnly")
class ValidatedMaskLayer(keras_nlp.layers.SelfAttentionMask):
def __init__(self, call_list, call_class=None, **kwargs):
super(ValidatedMaskLayer, self).__init__(**kwargs)
self.list = call_list
self.call_class = call_class
def call(self, inputs, mask):
self.list.append(True)
return super(ValidatedMaskLayer, self).call(inputs, mask)
def get_config(self):
config = super(ValidatedMaskLayer, self).get_config()
config["call_list"] = self.list
config["call_class"] = tf.keras.utils.get_registered_name(self.call_class)
return config
@tf.keras.utils.register_keras_serializable(package="TestLayerOnly")
class TestLayer(tf.keras.layers.Layer):
pass
@@ -95,6 +120,11 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
"call_list":
call_list
}
mask_call_list = []
mask_cfg = {
"call_list":
mask_call_list
}
# Create a small EncoderScaffold for testing.
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=num_hidden_instances,
@@ -103,6 +133,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
stddev=0.02),
hidden_cls=ValidatedTransformerLayer,
hidden_cfg=hidden_cfg,
mask_cls=ValidatedMaskLayer,
mask_cfg=mask_cfg,
embedding_cfg=embedding_cfg,
layer_norm_before_pooling=True,
return_all_layer_outputs=return_all_layer_outputs)
@@ -530,10 +562,15 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
"call_list":
call_list
}
mask_call_list = []
mask_cfg = {
"call_list": mask_call_list
}
# Create a small EncoderScaffold for testing. This time, we pass an already-
# instantiated layer object.
xformer = ValidatedTransformerLayer(**hidden_cfg)
xmask = ValidatedMaskLayer(**mask_cfg)
test_network = encoder_scaffold.EncoderScaffold(
num_hidden_instances=3,
@@ -541,6 +578,7 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cls=xformer,
mask_cls=xmask,
embedding_cfg=embedding_cfg)
# Create the inputs (note that the first dimension is implicit).
@@ -603,6 +641,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
"call_class":
TestLayer
}
mask_call_list = []
mask_cfg = {"call_list": mask_call_list, "call_class": TestLayer}
# Create a small EncoderScaffold for testing. This time, we pass an already-
# instantiated layer object.
kwargs = dict(
@@ -614,11 +654,16 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
if use_hidden_cls_instance:
xformer = ValidatedTransformerLayer(**hidden_cfg)
xmask = ValidatedMaskLayer(**mask_cfg)
test_network = encoder_scaffold.EncoderScaffold(
hidden_cls=xformer, **kwargs)
hidden_cls=xformer, mask_cls=xmask, **kwargs)
else:
test_network = encoder_scaffold.EncoderScaffold(
hidden_cls=ValidatedTransformerLayer, hidden_cfg=hidden_cfg, **kwargs)
hidden_cls=ValidatedTransformerLayer,
hidden_cfg=hidden_cfg,
mask_cls=ValidatedMaskLayer,
mask_cfg=mask_cfg,
**kwargs)
# Create another network object from the first object's config.
new_network = encoder_scaffold.EncoderScaffold.from_config(
......
@@ -375,14 +375,15 @@ class BigBirdMasks(tf.keras.layers.Layer):
super().__init__(**kwargs)
self._block_size = block_size
def call(self, inputs):
encoder_shape = tf.shape(inputs)
def call(self, inputs, mask):
encoder_shape = tf.shape(mask)
mask = tf.cast(mask, inputs.dtype)
batch_size, seq_length = encoder_shape[0], encoder_shape[1]
# reshape for blocking
blocked_encoder_mask = tf.reshape(
inputs, (batch_size, seq_length // self._block_size, self._block_size))
encoder_from_mask = tf.reshape(inputs, (batch_size, 1, seq_length, 1))
encoder_to_mask = tf.reshape(inputs, (batch_size, 1, 1, seq_length))
mask, (batch_size, seq_length // self._block_size, self._block_size))
encoder_from_mask = tf.reshape(mask, (batch_size, 1, seq_length, 1))
encoder_to_mask = tf.reshape(mask, (batch_size, 1, 1, seq_length))
band_mask = create_band_mask_from_inputs(blocked_encoder_mask,
blocked_encoder_mask)
......
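A short sketch of the new BigBirdMasks call convention (shapes and values are illustrative, not from this commit): the layer now takes the embeddings, which it only uses for the dtype cast, together with the 2D padding mask, instead of a single pre-cast mask tensor. This is what lets EncoderScaffold drive it through the generic mask_cls hook above.

import tensorflow as tf
from official.nlp.projects.bigbird import attention

block_size = 64
batch_size, seq_length, hidden_size = 2, 256, 16  # seq_length is a multiple of block_size
embeddings = tf.random.normal((batch_size, seq_length, hidden_size))
padding_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)

# New signature: (inputs, mask). Previously the layer took one pre-cast mask.
masks = attention.BigBirdMasks(block_size=block_size)(embeddings, padding_mask)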
@@ -29,7 +29,6 @@ class BigbirdAttentionTest(tf.test.TestCase):
block_size = 64
mask_layer = attention.BigBirdMasks(block_size=block_size)
encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
masks = mask_layer(tf.cast(encoder_inputs_mask, dtype=tf.float64))
test_layer = attention.BigBirdAttention(
num_heads=num_heads,
key_dim=key_dim,
@@ -38,6 +37,7 @@
seed=0)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim))
masks = mask_layer(query, tf.cast(encoder_inputs_mask, dtype=tf.float64))
value = query
output = test_layer(
query=query,
......
@@ -177,8 +177,7 @@ class BigBirdEncoder(tf.keras.Model):
self._transformer_layers = []
data = embeddings
masks = attention.BigBirdMasks(block_size=block_size)(
tf.cast(mask, embeddings.dtype))
masks = attention.BigBirdMasks(block_size=block_size)(data, mask)
encoder_outputs = []
attn_head_dim = hidden_size // num_attention_heads
for i in range(num_layers):
......