Commit 102f267e authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 374236491
parent 34731381
@@ -18,15 +18,15 @@ Includes configurations and factory methods.
 """
 from typing import Optional
-from absl import logging
 import dataclasses
 import gin
 import tensorflow as tf
 from official.modeling import hyperparams
 from official.modeling import tf_utils
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
-from official.nlp.projects.bigbird import encoder as bigbird_encoder
+from official.nlp.projects.bigbird import attention as bigbird_attention
 @dataclasses.dataclass
@@ -177,15 +177,6 @@ class EncoderConfig(hyperparams.OneOfConfig):
   xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
-ENCODER_CLS = {
-    "bert": networks.BertEncoder,
-    "mobilebert": networks.MobileBERTEncoder,
-    "albert": networks.AlbertEncoder,
-    "bigbird": bigbird_encoder.BigBirdEncoder,
-    "xlnet": networks.XLNetBase,
-}
 @gin.configurable
 def build_encoder(config: EncoderConfig,
                   embedding_layer: Optional[tf.keras.layers.Layer] = None,
@@ -205,13 +196,11 @@ def build_encoder(config: EncoderConfig,
   Returns:
     An encoder instance.
   """
-  encoder_type = config.type
-  encoder_cfg = config.get()
-  encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
-  logging.info("Encoder class: %s to build...", encoder_cls.__name__)
   if bypass_config:
     return encoder_cls()
-  if encoder_cls.__name__ == "EncoderScaffold":
+  encoder_type = config.type
+  encoder_cfg = config.get()
+  if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(
         vocab_size=encoder_cfg.vocab_size,
         type_vocab_size=encoder_cfg.type_vocab_size,
@@ -243,7 +232,7 @@ def build_encoder(config: EncoderConfig,
     return encoder_cls(**kwargs)
   if encoder_type == "mobilebert":
-    return encoder_cls(
+    return networks.MobileBERTEncoder(
         word_vocab_size=encoder_cfg.word_vocab_size,
         word_embed_size=encoder_cfg.word_embed_size,
         type_vocab_size=encoder_cfg.type_vocab_size,
@@ -265,7 +254,7 @@ def build_encoder(config: EncoderConfig,
         input_mask_dtype=encoder_cfg.input_mask_dtype)
   if encoder_type == "albert":
-    return encoder_cls(
+    return networks.AlbertEncoder(
         vocab_size=encoder_cfg.vocab_size,
         embedding_width=encoder_cfg.embedding_width,
         hidden_size=encoder_cfg.hidden_size,
@@ -282,26 +271,55 @@ def build_encoder(config: EncoderConfig,
         dict_outputs=True)
   if encoder_type == "bigbird":
-    return encoder_cls(
-        vocab_size=encoder_cfg.vocab_size,
-        hidden_size=encoder_cfg.hidden_size,
-        num_layers=encoder_cfg.num_layers,
-        num_attention_heads=encoder_cfg.num_attention_heads,
-        intermediate_size=encoder_cfg.intermediate_size,
-        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
-        dropout_rate=encoder_cfg.dropout_rate,
-        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
-        num_rand_blocks=encoder_cfg.num_rand_blocks,
-        block_size=encoder_cfg.block_size,
-        max_position_embeddings=encoder_cfg.max_position_embeddings,
-        type_vocab_size=encoder_cfg.type_vocab_size,
-        initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=encoder_cfg.initializer_range),
-        embedding_width=encoder_cfg.embedding_width,
-        use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
+    # TODO(frederickliu): Support use_gradient_checkpointing.
+    if encoder_cfg.use_gradient_checkpointing:
+      raise ValueError("Gradient checkpointing unsupported at the moment.")
+    embedding_cfg = dict(
+        vocab_size=encoder_cfg.vocab_size,
+        type_vocab_size=encoder_cfg.type_vocab_size,
+        hidden_size=encoder_cfg.hidden_size,
+        max_seq_length=encoder_cfg.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        dropout_rate=encoder_cfg.dropout_rate)
+    attention_cfg = dict(
+        num_heads=encoder_cfg.num_attention_heads,
+        key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads),
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        max_rand_mask_length=encoder_cfg.max_position_embeddings,
+        num_rand_blocks=encoder_cfg.num_rand_blocks,
+        from_block_size=encoder_cfg.block_size,
+        to_block_size=encoder_cfg.block_size,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        intermediate_size=encoder_cfg.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(
+            encoder_cfg.hidden_activation),
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        attention_cls=bigbird_attention.BigBirdAttention,
+        attention_cfg=attention_cfg)
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cls=layers.TransformerScaffold,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=encoder_cfg.num_layers,
+        mask_cls=bigbird_attention.BigBirdMasks,
+        mask_cfg=dict(block_size=encoder_cfg.block_size),
+        pooled_output_dim=encoder_cfg.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        return_all_layer_outputs=False,
+        dict_outputs=True,
+        layer_idx_as_attention_seed=True)
+    return networks.EncoderScaffold(**kwargs)
   if encoder_type == "xlnet":
-    return encoder_cls(
+    return networks.XLNetBase(
         vocab_size=encoder_cfg.vocab_size,
         num_layers=encoder_cfg.num_layers,
         hidden_size=encoder_cfg.hidden_size,
@@ -325,7 +343,7 @@ def build_encoder(config: EncoderConfig,
   # Uses the default BERTEncoder configuration schema to create the encoder.
   # If it does not match, please add a switch branch by the encoder type.
-  return encoder_cls(
+  return networks.BertEncoder(
       vocab_size=encoder_cfg.vocab_size,
       hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
...
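For context, a minimal sketch of exercising the rewritten `bigbird` branch follows. It assumes this file is `official.nlp.configs.encoders` and that `EncoderConfig` exposes a `bigbird` sub-config with the fields referenced above; neither is shown in this diff.

    from official.nlp.configs import encoders

    config = encoders.EncoderConfig(type="bigbird")
    # The new branch rejects gradient checkpointing for now: setting
    # config.bigbird.use_gradient_checkpointing = True would make
    # build_encoder raise a ValueError.
    encoder = encoders.build_encoder(config)
    # `encoder` is now a networks.EncoderScaffold whose hidden layers are
    # layers.TransformerScaffold blocks wrapping BigBirdAttention, with
    # BigBirdMasks as the mask layer, instead of a bigbird_encoder.BigBirdEncoder.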
@@ -14,6 +14,7 @@
 """Transformer-based text encoder network."""
 # pylint: disable=g-classes-have-attributes
+import copy
 import inspect
 from absl import logging
@@ -86,12 +87,19 @@ class EncoderScaffold(tf.keras.Model):
       `dropout_rate`: The overall dropout rate for the transformer layers.
       `attention_dropout_rate`: The dropout rate for the attention layers.
       `kernel_initializer`: The initializer for the transformer layers.
+    mask_cls: The class used to generate the attention masks passed into
+      hidden_cls(), computed from the inputs and the 2D mask indicating which
+      positions can be attended to. It is the caller's job to make sure the
+      output of the mask layer can be consumed by the hidden layer; a mask_cls
+      is usually paired with a specific hidden_cls.
+    mask_cfg: A dict of kwargs passed to mask_cls.
     layer_norm_before_pooling: Whether to add a layer norm before the pooling
       layer. You probably want to turn this on if you set `norm_first=True` in
       transformer layers.
     return_all_layer_outputs: Whether to output sequence embedding outputs of
       all encoder transformer layers.
     dict_outputs: Whether to use a dictionary as the model outputs.
+    layer_idx_as_attention_seed: Whether to include the layer index as the
+      attention `seed` in the `attention_cfg` inside hidden_cfg.
   """
   def __init__(self,
@@ -104,9 +112,12 @@ class EncoderScaffold(tf.keras.Model):
                num_hidden_instances=1,
                hidden_cls=layers.Transformer,
                hidden_cfg=None,
+               mask_cls=keras_nlp.layers.SelfAttentionMask,
+               mask_cfg=None,
                layer_norm_before_pooling=False,
                return_all_layer_outputs=False,
                dict_outputs=False,
+               layer_idx_as_attention_seed=False,
                **kwargs):
     if embedding_cls:
@@ -169,15 +180,25 @@ class EncoderScaffold(tf.keras.Model):
           tf.keras.layers.Dropout(
               rate=embedding_cfg['dropout_rate'])(embeddings))
-    attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)
+    mask_cfg = {} if mask_cfg is None else mask_cfg
+    if inspect.isclass(mask_cls):
+      mask_layer = mask_cls(**mask_cfg)
+    else:
+      mask_layer = mask_cls
+    attention_mask = mask_layer(embeddings, mask)
     data = embeddings
     layer_output_data = []
     hidden_layers = []
-    for _ in range(num_hidden_instances):
+    hidden_cfg = hidden_cfg if hidden_cfg else {}
+    for i in range(num_hidden_instances):
       if inspect.isclass(hidden_cls):
-        layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
+        if hidden_cfg and 'attention_cfg' in hidden_cfg and (
+            layer_idx_as_attention_seed):
+          hidden_cfg = copy.deepcopy(hidden_cfg)
+          hidden_cfg['attention_cfg']['seed'] = i
+        layer = hidden_cls(**hidden_cfg)
       else:
         layer = hidden_cls
       data = layer([data, attention_mask])
@@ -227,6 +248,8 @@ class EncoderScaffold(tf.keras.Model):
     self._hidden_cls = hidden_cls
     self._hidden_cfg = hidden_cfg
+    self._mask_cls = mask_cls
+    self._mask_cfg = mask_cfg
     self._num_hidden_instances = num_hidden_instances
     self._pooled_output_dim = pooled_output_dim
     self._pooler_layer_initializer = pooler_layer_initializer
@@ -247,6 +270,7 @@ class EncoderScaffold(tf.keras.Model):
     if self._layer_norm_before_pooling:
       self._output_layer_norm = output_layer_norm
     self._pooler_layer = pooler_layer
+    self._layer_idx_as_attention_seed = layer_idx_as_attention_seed
     logging.info('EncoderScaffold configs: %s', self.get_config())
@@ -260,32 +284,48 @@ class EncoderScaffold(tf.keras.Model):
         'layer_norm_before_pooling': self._layer_norm_before_pooling,
         'return_all_layer_outputs': self._return_all_layer_outputs,
         'dict_outputs': self._dict_outputs,
+        'layer_idx_as_attention_seed': self._layer_idx_as_attention_seed
     }
-    if self._hidden_cfg:
-      config_dict['hidden_cfg'] = {}
-      for k, v in self._hidden_cfg.items():
-        # `self._hidden_cfg` may contain `class`, e.g., when `hidden_cfg` is
-        # `TransformerScaffold`, its `attention_cls` argument can be a `class`.
-        if inspect.isclass(v):
-          config_dict['hidden_cfg'][k] = tf.keras.utils.get_registered_name(v)
-        else:
-          config_dict['hidden_cfg'][k] = v
-    if inspect.isclass(self._hidden_cls):
-      config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
-          self._hidden_cls)
-    else:
-      config_dict['hidden_cls'] = self._hidden_cls
+    cfgs = {
+        'hidden_cfg': self._hidden_cfg,
+        'mask_cfg': self._mask_cfg
+    }
+    for cfg_name, cfg in cfgs.items():
+      if cfg:
+        config_dict[cfg_name] = {}
+        for k, v in cfg.items():
+          # `self._hidden_cfg` may contain `class`, e.g., when `hidden_cfg` is
+          # `TransformerScaffold`, `attention_cls` argument can be a `class`.
+          if inspect.isclass(v):
+            config_dict[cfg_name][k] = tf.keras.utils.get_registered_name(v)
+          else:
+            config_dict[cfg_name][k] = v
+    clss = {
+        'hidden_cls': self._hidden_cls,
+        'mask_cls': self._mask_cls
+    }
+    for cls_name, cls in clss.items():
+      if inspect.isclass(cls):
+        key = '{}_string'.format(cls_name)
+        config_dict[key] = tf.keras.utils.get_registered_name(cls)
+      else:
+        config_dict[cls_name] = cls
     config_dict.update(self._kwargs)
     return config_dict
   @classmethod
   def from_config(cls, config, custom_objects=None):
-    if 'hidden_cls_string' in config:
-      config['hidden_cls'] = tf.keras.utils.get_registered_object(
-          config['hidden_cls_string'], custom_objects=custom_objects)
-      del config['hidden_cls_string']
+    cls_names = ['hidden_cls', 'mask_cls']
+    for cls_name in cls_names:
+      cls_string = '{}_string'.format(cls_name)
+      if cls_string in config:
+        config[cls_name] = tf.keras.utils.get_registered_object(
+            config[cls_string], custom_objects=custom_objects)
+        del config[cls_string]
     return cls(**config)
   def get_embedding_table(self):
...
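To make the new arguments concrete, a rough usage sketch follows. It mirrors the kwargs assembled in the `bigbird` branch of `build_encoder` above, but the sizes and config values are illustrative placeholders rather than anything taken from this change:

    import tensorflow as tf
    from official.modeling import tf_utils
    from official.nlp.modeling import layers
    from official.nlp.modeling.networks import encoder_scaffold
    from official.nlp.projects.bigbird import attention as bigbird_attention

    init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
    network = encoder_scaffold.EncoderScaffold(
        num_hidden_instances=2,
        pooled_output_dim=64,
        pooler_layer_initializer=init,
        embedding_cfg=dict(
            vocab_size=100, type_vocab_size=2, hidden_size=64,
            max_seq_length=128, initializer=init, dropout_rate=0.1),
        hidden_cls=layers.TransformerScaffold,
        hidden_cfg=dict(
            num_attention_heads=4, intermediate_size=256,
            intermediate_activation=tf_utils.get_activation("gelu"),
            dropout_rate=0.1, attention_dropout_rate=0.1,
            kernel_initializer=init,
            attention_cls=bigbird_attention.BigBirdAttention,
            attention_cfg=dict(
                num_heads=4, key_dim=16, kernel_initializer=init,
                max_rand_mask_length=128, num_rand_blocks=2,
                from_block_size=16, to_block_size=16)),
        # New in this change: the mask layer is configurable, and each hidden
        # layer's attention_cfg can receive its layer index as `seed`.
        mask_cls=bigbird_attention.BigBirdMasks,
        mask_cfg=dict(block_size=16),
        layer_idx_as_attention_seed=True,
        dict_outputs=True)

By default `mask_cls` stays `keras_nlp.layers.SelfAttentionMask`, so existing EncoderScaffold call sites keep their previous behavior; `get_config()` and `from_config()` round-trip the new `mask_cls`/`mask_cfg` the same way they already handled `hidden_cls`/`hidden_cfg`.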
@@ -20,6 +20,7 @@ import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
 from official.modeling import activations
+from official.nlp import keras_nlp
 from official.nlp.modeling import layers
 from official.nlp.modeling.networks import encoder_scaffold
@@ -47,6 +48,30 @@ class ValidatedTransformerLayer(layers.Transformer):
     return config
+# Test class that wraps a standard self attention mask layer. If this layer
+# is called at any point, the list passed to the config object will be filled
+# with a boolean 'True'. We register this class as a Keras serializable so we
+# can test serialization below.
+@tf.keras.utils.register_keras_serializable(package="TestOnly")
+class ValidatedMaskLayer(keras_nlp.layers.SelfAttentionMask):
+
+  def __init__(self, call_list, call_class=None, **kwargs):
+    super(ValidatedMaskLayer, self).__init__(**kwargs)
+    self.list = call_list
+    self.call_class = call_class
+
+  def call(self, inputs, mask):
+    self.list.append(True)
+    return super(ValidatedMaskLayer, self).call(inputs, mask)
+
+  def get_config(self):
+    config = super(ValidatedMaskLayer, self).get_config()
+    config["call_list"] = self.list
+    config["call_class"] = tf.keras.utils.get_registered_name(self.call_class)
+    return config
 @tf.keras.utils.register_keras_serializable(package="TestLayerOnly")
 class TestLayer(tf.keras.layers.Layer):
   pass
@@ -95,6 +120,11 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
         "call_list":
             call_list
     }
+    mask_call_list = []
+    mask_cfg = {
+        "call_list":
+            mask_call_list
+    }
     # Create a small EncoderScaffold for testing.
     test_network = encoder_scaffold.EncoderScaffold(
         num_hidden_instances=num_hidden_instances,
@@ -103,6 +133,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
             stddev=0.02),
         hidden_cls=ValidatedTransformerLayer,
         hidden_cfg=hidden_cfg,
+        mask_cls=ValidatedMaskLayer,
+        mask_cfg=mask_cfg,
         embedding_cfg=embedding_cfg,
         layer_norm_before_pooling=True,
         return_all_layer_outputs=return_all_layer_outputs)
@@ -530,10 +562,15 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         "call_list":
             call_list
     }
+    mask_call_list = []
+    mask_cfg = {
+        "call_list": mask_call_list
+    }
     # Create a small EncoderScaffold for testing. This time, we pass an already-
     # instantiated layer object.
     xformer = ValidatedTransformerLayer(**hidden_cfg)
+    xmask = ValidatedMaskLayer(**mask_cfg)
     test_network = encoder_scaffold.EncoderScaffold(
         num_hidden_instances=3,
@@ -541,6 +578,7 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=0.02),
         hidden_cls=xformer,
+        mask_cls=xmask,
         embedding_cfg=embedding_cfg)
     # Create the inputs (note that the first dimension is implicit).
@@ -603,6 +641,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         "call_class":
            TestLayer
     }
+    mask_call_list = []
+    mask_cfg = {"call_list": mask_call_list, "call_class": TestLayer}
     # Create a small EncoderScaffold for testing. This time, we pass an already-
     # instantiated layer object.
     kwargs = dict(
@@ -614,11 +654,16 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
     if use_hidden_cls_instance:
       xformer = ValidatedTransformerLayer(**hidden_cfg)
+      xmask = ValidatedMaskLayer(**mask_cfg)
       test_network = encoder_scaffold.EncoderScaffold(
-          hidden_cls=xformer, **kwargs)
+          hidden_cls=xformer, mask_cls=xmask, **kwargs)
     else:
       test_network = encoder_scaffold.EncoderScaffold(
-          hidden_cls=ValidatedTransformerLayer, hidden_cfg=hidden_cfg, **kwargs)
+          hidden_cls=ValidatedTransformerLayer,
+          hidden_cfg=hidden_cfg,
+          mask_cls=ValidatedMaskLayer,
+          mask_cfg=mask_cfg,
+          **kwargs)
     # Create another network object from the first object's config.
     new_network = encoder_scaffold.EncoderScaffold.from_config(
...
@@ -375,14 +375,15 @@ class BigBirdMasks(tf.keras.layers.Layer):
     super().__init__(**kwargs)
     self._block_size = block_size
-  def call(self, inputs):
-    encoder_shape = tf.shape(inputs)
+  def call(self, inputs, mask):
+    encoder_shape = tf.shape(mask)
+    mask = tf.cast(mask, inputs.dtype)
     batch_size, seq_length = encoder_shape[0], encoder_shape[1]
     # reshape for blocking
     blocked_encoder_mask = tf.reshape(
-        inputs, (batch_size, seq_length // self._block_size, self._block_size))
-    encoder_from_mask = tf.reshape(inputs, (batch_size, 1, seq_length, 1))
-    encoder_to_mask = tf.reshape(inputs, (batch_size, 1, 1, seq_length))
+        mask, (batch_size, seq_length // self._block_size, self._block_size))
+    encoder_from_mask = tf.reshape(mask, (batch_size, 1, seq_length, 1))
+    encoder_to_mask = tf.reshape(mask, (batch_size, 1, 1, seq_length))
     band_mask = create_band_mask_from_inputs(blocked_encoder_mask,
                                              blocked_encoder_mask)
...
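The mask layer's call convention changes here: it now takes the hidden inputs and the 2D padding mask as separate arguments and casts the mask to the inputs' dtype itself. A small sketch mirroring the updated test and BigBirdEncoder call sites below (shapes are illustrative):

    import tensorflow as tf
    from official.nlp.projects.bigbird import attention

    mask_layer = attention.BigBirdMasks(block_size=64)
    inputs = tf.random.normal((2, 128, 16))            # [batch, seq_length, width]
    padding_mask = tf.ones((2, 128), dtype=tf.int32)   # 1 = position can be attended to
    # Previously: masks = mask_layer(tf.cast(padding_mask, inputs.dtype))
    masks = mask_layer(inputs, padding_mask)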
@@ -29,7 +29,6 @@ class BigbirdAttentionTest(tf.test.TestCase):
     block_size = 64
     mask_layer = attention.BigBirdMasks(block_size=block_size)
     encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
-    masks = mask_layer(tf.cast(encoder_inputs_mask, dtype=tf.float64))
     test_layer = attention.BigBirdAttention(
         num_heads=num_heads,
         key_dim=key_dim,
@@ -38,6 +37,7 @@ class BigbirdAttentionTest(tf.test.TestCase):
         seed=0)
     query = tf.random.normal(
         shape=(batch_size, seq_length, key_dim))
+    masks = mask_layer(query, tf.cast(encoder_inputs_mask, dtype=tf.float64))
     value = query
     output = test_layer(
         query=query,
...
@@ -177,8 +177,7 @@ class BigBirdEncoder(tf.keras.Model):
     self._transformer_layers = []
     data = embeddings
-    masks = attention.BigBirdMasks(block_size=block_size)(
-        tf.cast(mask, embeddings.dtype))
+    masks = attention.BigBirdMasks(block_size=block_size)(data, mask)
     encoder_outputs = []
     attn_head_dim = hidden_size // num_attention_heads
     for i in range(num_layers):
...