Commit 356c98bd authored by Kaushik Shivakumar

Merge remote-tracking branch 'upstream/master' into detr-push-3

parents d31aba8a b9785623
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT configurations and models instantiation."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders


class BertModelsTest(tf.test.TestCase):

  def test_network_invocation(self):
    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
    _ = bert.instantiate_pretrainer_from_cfg(config)

    # Invokes with classification heads.
    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=2, name="next_sentence")
        ])
    _ = bert.instantiate_pretrainer_from_cfg(config)

    with self.assertRaises(ValueError):
      config = bert.BertPretrainerConfig(
          encoder=encoders.TransformerEncoderConfig(
              vocab_size=10, num_layers=1),
          cls_heads=[
              bert.ClsHeadConfig(
                  inner_dim=10, num_classes=2, name="next_sentence"),
              bert.ClsHeadConfig(
                  inner_dim=10, num_classes=2, name="next_sentence")
          ])
      _ = bert.instantiate_pretrainer_from_cfg(config)

  def test_checkpoint_items(self):
    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=2, name="next_sentence")
        ])
    encoder = bert.instantiate_pretrainer_from_cfg(config)
    self.assertSameElements(
        encoder.checkpoint_items.keys(),
        ["encoder", "masked_lm", "next_sentence.pooler_dense"])


if __name__ == "__main__":
  tf.test.main()
@@ -14,21 +14,17 @@
 # limitations under the License.
 # ==============================================================================
 """ELECTRA model configurations and instantiation methods."""
-from typing import List, Optional
+from typing import List
 import dataclasses
-import tensorflow as tf
-from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
 from official.nlp.configs import bert
 from official.nlp.configs import encoders
-from official.nlp.modeling import layers
-from official.nlp.modeling.models import electra_pretrainer
 @dataclasses.dataclass
-class ELECTRAPretrainerConfig(base_config.Config):
+class ElectraPretrainerConfig(base_config.Config):
   """ELECTRA pretrainer configuration."""
   num_masked_tokens: int = 76
   sequence_length: int = 512
@@ -36,56 +32,6 @@ class ELECTRAPretrainerConfig(base_config.Config):
   discriminator_loss_weight: float = 50.0
   tie_embeddings: bool = True
   disallow_correct: bool = False
-  generator_encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
-  discriminator_encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  generator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
+  discriminator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
   cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
-def instantiate_classification_heads_from_cfgs(
-    cls_head_configs: List[bert.ClsHeadConfig]
-) -> List[layers.ClassificationHead]:
-  if cls_head_configs:
-    return [
-        layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
-    ]
-  else:
-    return []
-def instantiate_pretrainer_from_cfg(
-    config: ELECTRAPretrainerConfig,
-    generator_network: Optional[tf.keras.Model] = None,
-    discriminator_network: Optional[tf.keras.Model] = None,
-) -> electra_pretrainer.ElectraPretrainer:
-  """Instantiates ElectraPretrainer from the config."""
-  generator_encoder_cfg = config.generator_encoder
-  discriminator_encoder_cfg = config.discriminator_encoder
-  # Copy discriminator's embeddings to generator for easier model serialization.
-  if discriminator_network is None:
-    discriminator_network = encoders.instantiate_encoder_from_cfg(
-        discriminator_encoder_cfg)
-  if generator_network is None:
-    if config.tie_embeddings:
-      embedding_layer = discriminator_network.get_embedding_layer()
-      generator_network = encoders.instantiate_encoder_from_cfg(
-          generator_encoder_cfg, embedding_layer=embedding_layer)
-    else:
-      generator_network = encoders.instantiate_encoder_from_cfg(
-          generator_encoder_cfg)
-  return electra_pretrainer.ElectraPretrainer(
-      generator_network=generator_network,
-      discriminator_network=discriminator_network,
-      vocab_size=config.generator_encoder.vocab_size,
-      num_classes=config.num_classes,
-      sequence_length=config.sequence_length,
-      num_token_predictions=config.num_masked_tokens,
-      mlm_activation=tf_utils.get_activation(
-          generator_encoder_cfg.hidden_activation),
-      mlm_initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=generator_encoder_cfg.initializer_range),
-      classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads),
-      disallow_correct=config.disallow_correct)
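
For orientation, the new-style config above would be constructed roughly as follows. This is a minimal sketch based only on the fields visible in this diff (the one-of `encoders.EncoderConfig` with a `bert` branch); the training code that consumes the config is not shown here.

from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders

# Small generator, larger discriminator, both selected through the one-of
# EncoderConfig; "bert" is the only registered encoder type in this change.
config = electra.ElectraPretrainerConfig(
    generator_encoder=encoders.EncoderConfig(
        type="bert",
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=3)),
    discriminator_encoder=encoders.EncoderConfig(
        type="bert",
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=12)),
    cls_heads=[
        bert.ClsHeadConfig(inner_dim=768, num_classes=2, name="next_sentence")
    ])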
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders


class ELECTRAModelsTest(tf.test.TestCase):

  def test_network_invocation(self):
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
    )
    _ = electra.instantiate_pretrainer_from_cfg(config)

    # Invokes with classification heads.
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=2, name="next_sentence")
        ])
    _ = electra.instantiate_pretrainer_from_cfg(config)


if __name__ == "__main__":
  tf.test.main()
@@ -15,20 +15,23 @@
 # ==============================================================================
 """Transformer Encoders.
-Includes configurations and instantiation methods.
+Includes configurations and factory methods.
 """
 from typing import Optional
+from absl import logging
 import dataclasses
+import gin
 import tensorflow as tf
+from official.modeling import hyperparams
 from official.modeling import tf_utils
-from official.modeling.hyperparams import base_config
 from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 @dataclasses.dataclass
-class TransformerEncoderConfig(base_config.Config):
+class BertEncoderConfig(hyperparams.Config):
   """BERT encoder configuration."""
   vocab_size: int = 30522
   hidden_size: int = 768
@@ -44,57 +47,86 @@ class TransformerEncoderConfig(base_config.Config):
   embedding_size: Optional[int] = None
-def instantiate_encoder_from_cfg(
-    config: TransformerEncoderConfig,
-    encoder_cls=networks.TransformerEncoder,
-    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
-  """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
+@dataclasses.dataclass
+class EncoderConfig(hyperparams.OneOfConfig):
+  """Encoder configuration."""
+  type: Optional[str] = "bert"
+  bert: BertEncoderConfig = BertEncoderConfig()
+ENCODER_CLS = {
+    "bert": networks.TransformerEncoder,
+}
+@gin.configurable
+def build_encoder(config: EncoderConfig,
+                  embedding_layer: Optional[layers.OnDeviceEmbedding] = None,
+                  encoder_cls=None,
+                  bypass_config: bool = False):
+  """Instantiate a Transformer encoder network from EncoderConfig.
+  Args:
+    config: the one-of encoder config, which provides encoder parameters of a
+      chosen encoder.
+    embedding_layer: an external embedding layer passed to the encoder.
+    encoder_cls: an external encoder cls not included in the supported encoders,
+      usually used by gin.configurable.
+    bypass_config: whether to ignore config instance to create the object with
+      `encoder_cls`.
+  Returns:
+    An encoder instance.
+  """
+  encoder_type = config.type
+  encoder_cfg = config.get()
+  encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
+  logging.info("Encoder class: %s to build...", encoder_cls.__name__)
+  if bypass_config:
+    return encoder_cls()
   if encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(
-        vocab_size=config.vocab_size,
-        type_vocab_size=config.type_vocab_size,
-        hidden_size=config.hidden_size,
-        seq_length=None,
-        max_seq_length=config.max_position_embeddings,
+        vocab_size=encoder_cfg.vocab_size,
+        type_vocab_size=encoder_cfg.type_vocab_size,
+        hidden_size=encoder_cfg.hidden_size,
+        max_seq_length=encoder_cfg.max_position_embeddings,
         initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=config.initializer_range),
-        dropout_rate=config.dropout_rate,
+            stddev=encoder_cfg.initializer_range),
+        dropout_rate=encoder_cfg.dropout_rate,
     )
     hidden_cfg = dict(
-        num_attention_heads=config.num_attention_heads,
-        intermediate_size=config.intermediate_size,
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        intermediate_size=encoder_cfg.intermediate_size,
         intermediate_activation=tf_utils.get_activation(
-            config.hidden_activation),
-        dropout_rate=config.dropout_rate,
-        attention_dropout_rate=config.attention_dropout_rate,
+            encoder_cfg.hidden_activation),
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=config.initializer_range),
+            stddev=encoder_cfg.initializer_range),
     )
     kwargs = dict(
         embedding_cfg=embedding_cfg,
         hidden_cfg=hidden_cfg,
-        num_hidden_instances=config.num_layers,
-        pooled_output_dim=config.hidden_size,
+        num_hidden_instances=encoder_cfg.num_layers,
+        pooled_output_dim=encoder_cfg.hidden_size,
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=config.initializer_range))
+            stddev=encoder_cfg.initializer_range))
     return encoder_cls(**kwargs)
-  if encoder_cls.__name__ != "TransformerEncoder":
-    raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
-  encoder_network = encoder_cls(
-      vocab_size=config.vocab_size,
-      hidden_size=config.hidden_size,
-      num_layers=config.num_layers,
-      num_attention_heads=config.num_attention_heads,
-      intermediate_size=config.intermediate_size,
-      activation=tf_utils.get_activation(config.hidden_activation),
-      dropout_rate=config.dropout_rate,
-      attention_dropout_rate=config.attention_dropout_rate,
-      sequence_length=None,
-      max_sequence_length=config.max_position_embeddings,
-      type_vocab_size=config.type_vocab_size,
+  # Uses the default BERTEncoder configuration schema to create the encoder.
+  # If it does not match, please add a switch branch by the encoder type.
+  return encoder_cls(
+      vocab_size=encoder_cfg.vocab_size,
+      hidden_size=encoder_cfg.hidden_size,
+      num_layers=encoder_cfg.num_layers,
+      num_attention_heads=encoder_cfg.num_attention_heads,
+      intermediate_size=encoder_cfg.intermediate_size,
+      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+      dropout_rate=encoder_cfg.dropout_rate,
+      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+      max_sequence_length=encoder_cfg.max_position_embeddings,
+      type_vocab_size=encoder_cfg.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range),
-      embedding_width=config.embedding_size,
+          stddev=encoder_cfg.initializer_range),
+      embedding_width=encoder_cfg.embedding_size,
       embedding_layer=embedding_layer)
-  return encoder_network
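
A minimal usage sketch of the new factory, assuming only the definitions above: the one-of config's `type` field selects the active sub-config, `config.get()` returns it, and the class is looked up in `ENCODER_CLS` unless an external `encoder_cls` is injected (for example via gin).

from official.nlp.configs import encoders

config = encoders.EncoderConfig(
    type="bert",
    bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=12))

# Resolves to networks.TransformerEncoder and builds it from BertEncoderConfig.
encoder = encoders.build_encoder(config)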
-# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,38 +12,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Set of blocks related to entropy coding."""
-import math
-import tensorflow as tf
-import block_base
-# pylint does not recognize block_base.BlockBase.__call__().
-# pylint: disable=not-callable
-class CodeLength(block_base.BlockBase):
-  """Theoretical bound for a code length given a probability distribution.
-  """
-  def __init__(self, name=None):
-    super(CodeLength, self).__init__(name)
-  def _Apply(self, c, p):
-    """Theoretical bound of the coded length given a probability distribution.
-    Args:
-      c: The binary codes. Belong to {0, 1}.
-      p: The probability of: P(code==+1)
-    Returns:
-      The average code length.
-      Note: the average code length can be greater than 1 bit (e.g. when
-      encoding the least likely symbol).
-    """
-    entropy = ((1.0 - c) * tf.log(1.0 - p) + c * tf.log(p)) / (-math.log(2))
-    entropy = tf.reduce_mean(entropy)
-    return entropy
+"""An abstraction that NLP models define input pipelines."""
+import abc
+from typing import Optional
+import tensorflow as tf
+class DataLoader(metaclass=abc.ABCMeta):
+  """An abstract class defining the APIs for tf.data input pipeline."""
+
+  @abc.abstractmethod
+  def load(
+      self,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
+    """Implements DataLoader load method.
+    Builds the entire input pipeline inside the load method. Users can define
+    states inside the DataLoader class and returns a tf.data dataset
+    object.
+    Args:
+      input_context: This is a context class that is passed to the user's input
+        function and contains information about the compute replicas and input
+        pipelines. This object is used for multi-host inputs and passed by
+        the distribution strategy.
+    Returns:
+      A per-host tf.data dataset. Note that, we usually create the distributed
+      dataset through the load method, so we should not directly return a
+      distributed dataset here.
+    """
+    pass
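
As a sketch of how the abstract contract might be satisfied, a toy loader could look like the following; the dataset contents are placeholders and not part of this change.

from typing import Optional

import tensorflow as tf

from official.nlp.data import data_loader


class ToyDataLoader(data_loader.DataLoader):
  """Builds a small per-host dataset of (feature, label) pairs."""

  def __init__(self, batch_size=8):
    self._batch_size = batch_size

  def load(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    dataset = tf.data.Dataset.range(1000)
    if input_context:
      # Shard per input pipeline; the distribution strategy turns this into a
      # distributed dataset outside of load().
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    return dataset.map(
        lambda x: (tf.cast(x, tf.float32), x % 2)).batch(self._batch_size)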
@@ -21,6 +21,7 @@ import tensorflow as tf
 from official.core import input_reader
 from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader
 from official.nlp.data import data_loader_factory
@@ -37,7 +38,7 @@ class BertPretrainDataConfig(cfg.DataConfig):
 @data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
-class BertPretrainDataLoader:
+class BertPretrainDataLoader(data_loader.DataLoader):
   """A class to load dataset for bert pretraining task."""
   def __init__(self, params):
...
@@ -20,6 +20,7 @@ import tensorflow as tf
 from official.core import input_reader
 from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader
 from official.nlp.data import data_loader_factory
@@ -42,7 +43,7 @@ class QADataConfig(cfg.DataConfig):
 @data_loader_factory.register_data_loader_cls(QADataConfig)
-class QuestionAnsweringDataLoader:
+class QuestionAnsweringDataLoader(data_loader.DataLoader):
   """A class to load dataset for sentence prediction (classification) task."""
   def __init__(self, params):
...
@@ -20,6 +20,7 @@ import tensorflow as tf
 from official.core import input_reader
 from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader
 from official.nlp.data import data_loader_factory
@@ -37,7 +38,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
 @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
-class SentencePredictionDataLoader:
+class SentencePredictionDataLoader(data_loader.DataLoader):
   """A class to load dataset for sentence prediction (classification) task."""
   def __init__(self, params):
...
@@ -23,6 +23,7 @@ from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
 from official.nlp.modeling.layers.multi_channel_attention import *
 from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
 from official.nlp.modeling.layers.position_embedding import PositionEmbedding
+from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
 from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
 from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
 from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
...
@@ -521,11 +521,11 @@ class CachedAttention(MultiHeadAttention):
     if cache:
       key, value = self._update_cache(key, value, cache, decode_loop_step)
+    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
     attention_scores = tf.einsum(self._dot_product_equation, key, query)
-    attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._key_size)))
     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, F, T]
...
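
The change above scales the query by 1/sqrt(key_size) before the einsum rather than scaling the raw scores afterwards; the two orderings are mathematically equivalent because the scalar factors out of the dot product. A quick sanity check (a plain matmul stands in for the layer's private einsum equation):

import math

import tensorflow as tf

key_size = 8
query = tf.random.normal([2, 4, key_size])
key = tf.random.normal([2, 4, key_size])
scale = 1.0 / math.sqrt(float(key_size))

scores_scaled_after = tf.matmul(query, key, transpose_b=True) * scale
scores_scaled_before = tf.matmul(query * scale, key, transpose_b=True)

# Equal up to floating-point rounding.
tf.debugging.assert_near(scores_scaled_after, scores_scaled_before)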
@@ -34,7 +34,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
   def create_layer(self,
                    vocab_size,
-                   sequence_length,
                    hidden_size,
                    output='predictions',
                    xformer_stack=None):
@@ -44,7 +43,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
       xformer_stack = transformer_encoder.TransformerEncoder(
           vocab_size=vocab_size,
           num_layers=1,
-          sequence_length=sequence_length,
           hidden_size=hidden_size,
           num_attention_heads=4,
       )
@@ -62,7 +60,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
     num_predictions = 21
     test_layer = self.create_layer(
         vocab_size=vocab_size,
-        sequence_length=sequence_length,
         hidden_size=hidden_size)
     # Make sure that the output tensor of the masked LM is the right shape.
@@ -81,19 +78,16 @@ class MaskedLMTest(keras_parameterized.TestCase):
     xformer_stack = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
         num_layers=1,
-        sequence_length=sequence_length,
         hidden_size=hidden_size,
         num_attention_heads=4,
     )
     test_layer = self.create_layer(
         vocab_size=vocab_size,
-        sequence_length=sequence_length,
         hidden_size=hidden_size,
         xformer_stack=xformer_stack,
         output='predictions')
     logit_layer = self.create_layer(
         vocab_size=vocab_size,
-        sequence_length=sequence_length,
         hidden_size=hidden_size,
         xformer_stack=xformer_stack,
         output='logits')
@@ -134,7 +128,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
     num_predictions = 21
     test_layer = self.create_layer(
         vocab_size=vocab_size,
-        sequence_length=sequence_length,
         hidden_size=hidden_size)
     # Create a model from the masked LM layer.
@@ -155,7 +148,7 @@ class MaskedLMTest(keras_parameterized.TestCase):
   def test_unknown_output_type_fails(self):
     with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
       _ = self.create_layer(
-          vocab_size=8, sequence_length=8, hidden_size=8, output='bad')
+          vocab_size=8, hidden_size=8, output='bad')
 if __name__ == '__main__':
...
@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
       lookup. Defaults to False (that is, using tf.gather). Setting this option
       to True may improve performance, especially on small vocabulary sizes, but
       will generally require more memory.
+    use_scale: Whether to scale the output embeddings. Defaults to False (that
+      is, not to scale). Setting this option to True will let values in output
+      embeddings multiplied by self._embedding_width ** 0.5.
   """
   def __init__(self,
@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
                embedding_width,
                initializer="glorot_uniform",
                use_one_hot=False,
+               use_scale=False,
                **kwargs):
     super(OnDeviceEmbedding, self).__init__(**kwargs)
@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
     self._embedding_width = embedding_width
     self._initializer = initializer
     self._use_one_hot = use_one_hot
+    self._use_scale = use_scale
   def get_config(self):
     config = {
@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
         "embedding_width": self._embedding_width,
         "initializer": self._initializer,
         "use_one_hot": self._use_one_hot,
+        "use_scale": self._use_scale,
     }
     base_config = super(OnDeviceEmbedding, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
         # Work around b/142213824: prefer concat to shape over a Python list.
         tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
     embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
+    if self._use_scale:
+      embeddings *= self._embedding_width ** 0.5
     return embeddings
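
A short sketch of the new flag: with `use_scale=True` the lookup output is multiplied by sqrt(embedding_width), the scaling conventionally applied to token embeddings in Transformer models.

import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import on_device_embedding

layer = on_device_embedding.OnDeviceEmbedding(
    vocab_size=100, embedding_width=16, use_scale=True)
ids = np.random.randint(100, size=(2, 5))
embeddings = layer(ids)  # shape (2, 5, 16); values scaled by 16 ** 0.5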
@@ -193,6 +193,26 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
     output = model.predict(input_data)
     self.assertEqual(tf.float16, output.dtype)
+  def test_use_scale_layer_invocation(self):
+    vocab_size = 31
+    embedding_width = 27
+    test_layer = on_device_embedding.OnDeviceEmbedding(
+        vocab_size=vocab_size, embedding_width=embedding_width, use_scale=True)
+    # Create a 2-dimensional input (the first dimension is implicit).
+    sequence_length = 23
+    input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
+    output_tensor = test_layer(input_tensor)
+
+    # Create a model from the test layer.
+    model = tf.keras.Model(input_tensor, output_tensor)
+
+    # Invoke the model on test data. We can't validate the output data itself
+    # (the NN is too complex) but this will rule out structural runtime errors.
+    batch_size = 3
+    input_data = np.random.randint(
+        vocab_size, size=(batch_size, sequence_length))
+    output = model.predict(input_data)
+    self.assertEqual(tf.float32, output.dtype)
 if __name__ == "__main__":
   tf.test.main()
@@ -49,6 +49,13 @@ class Transformer(tf.keras.layers.Layer):
     activity_regularizer: Regularizer for dense layer activity.
     kernel_constraint: Constraint for dense layer kernels.
     bias_constraint: Constraint for dense layer kernels.
+    use_bias: Whether to enable use_bias in attention layer. If set False,
+      use_bias in attention layer is disabled.
+    norm_first: Whether to normalize inputs to attention and intermediate dense
+      layers. If set False, output of attention and intermediate dense layers is
+      normalized.
+    norm_epsilon: Epsilon value to initialize normalization layers.
+    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
   """
   def __init__(self,
@@ -65,6 +72,10 @@ class Transformer(tf.keras.layers.Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               use_bias=True,
+               norm_first=False,
+               norm_epsilon=1e-12,
+               intermediate_dropout=0.0,
                **kwargs):
     super(Transformer, self).__init__(**kwargs)
@@ -81,6 +92,10 @@ class Transformer(tf.keras.layers.Layer):
     self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+    self._use_bias = use_bias
+    self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon
+    self._intermediate_dropout = intermediate_dropout
   def build(self, input_shape):
     input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
@@ -117,14 +132,9 @@ class Transformer(tf.keras.layers.Layer):
         num_heads=self._num_heads,
         key_size=self._attention_head_size,
         dropout=self._attention_dropout_rate,
+        use_bias=self._use_bias,
         name="self_attention",
         **common_kwargs)
-    # pylint: disable=protected-access
-    # Temporarily handling for checkpoint compatible changes.
-    self._attention_layer._build_from_signature(
-        query=input_tensor_shape, value=input_tensor_shape)
-    self._attention_output_dense = self._attention_layer._output_dense
-    # pylint: enable=protected-access
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     # It is probably safe in mixed_float16, but we haven't validated this yet.
@@ -132,7 +142,7 @@ class Transformer(tf.keras.layers.Layer):
         tf.keras.layers.LayerNormalization(
             name="self_attention_layer_norm",
             axis=-1,
-            epsilon=1e-12,
+            epsilon=self._norm_epsilon,
             dtype=tf.float32))
     self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
@@ -148,6 +158,8 @@ class Transformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
+    self._intermediate_dropout_layer = tf.keras.layers.Dropout(
+        rate=self._intermediate_dropout)
     self._output_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
@@ -157,7 +169,10 @@ class Transformer(tf.keras.layers.Layer):
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
+        name="output_layer_norm",
+        axis=-1,
+        epsilon=self._norm_epsilon,
+        dtype=tf.float32)
     super(Transformer, self).build(input_shape)
@@ -188,7 +203,15 @@ class Transformer(tf.keras.layers.Layer):
         "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
         "bias_constraint":
-            tf.keras.constraints.serialize(self._bias_constraint)
+            tf.keras.constraints.serialize(self._bias_constraint),
+        "use_bias":
+            self._use_bias,
+        "norm_first":
+            self._norm_first,
+        "norm_epsilon":
+            self._norm_epsilon,
+        "intermediate_dropout":
+            self._intermediate_dropout
     }
     base_config = super(Transformer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -203,23 +226,36 @@ class Transformer(tf.keras.layers.Layer):
       target_tensor = input_tensor[:, 0:self._output_range, :]
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
+      if self._norm_first:
+        source_tensor = input_tensor
+        input_tensor = self._attention_layer_norm(input_tensor)
       target_tensor = input_tensor
     attention_output = self._attention_layer(
         query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
-    attention_output = self._attention_layer_norm(target_tensor +
-                                                  attention_output)
+    if self._norm_first:
+      attention_output = source_tensor + attention_output
+    else:
+      attention_output = self._attention_layer_norm(target_tensor +
+                                                    attention_output)
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self._output_layer_norm(attention_output)
     intermediate_output = self._intermediate_dense(attention_output)
     intermediate_output = self._intermediate_activation_layer(
         intermediate_output)
+    intermediate_output = self._intermediate_dropout_layer(intermediate_output)
     layer_output = self._output_dense(intermediate_output)
     layer_output = self._output_dropout(layer_output)
     # During mixed precision training, attention_output is from layer norm and
     # is always fp32 for now. Cast layer_output to fp32 for the subsequent
     # add.
     layer_output = tf.cast(layer_output, tf.float32)
-    layer_output = self._output_layer_norm(layer_output + attention_output)
+    if self._norm_first:
+      layer_output = source_attention_output + layer_output
+    else:
+      layer_output = self._output_layer_norm(layer_output + attention_output)
     return layer_output
@@ -257,6 +293,13 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     activity_regularizer: Regularizer for dense layer activity.
     kernel_constraint: Constraint for dense layer kernels.
     bias_constraint: Constraint for dense layer kernels.
+    use_bias: Whether to enable use_bias in attention layer. If set False,
+      use_bias in attention layer is disabled.
+    norm_first: Whether to normalize inputs to attention and intermediate dense
+      layers. If set False, output of attention and intermediate dense layers is
+      normalized.
+    norm_epsilon: Epsilon value to initialize normalization layers.
+    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
   """
   def __init__(self,
@@ -273,6 +316,10 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               use_bias=True,
+               norm_first=False,
+               norm_epsilon=1e-12,
+               intermediate_dropout=0.0,
                **kwargs):
     super(TransformerDecoderLayer, self).__init__(**kwargs)
     self.num_attention_heads = num_attention_heads
@@ -289,6 +336,10 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+    self._use_bias = use_bias
+    self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon
+    self._intermediate_dropout = intermediate_dropout
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
@@ -318,6 +369,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
+        use_bias=self._use_bias,
         name="self_attention",
         **common_kwargs)
     self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
@@ -330,13 +382,16 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         rate=self.dropout_rate)
     self.self_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
+            name="self_attention_layer_norm",
+            axis=-1,
+            epsilon=self._norm_epsilon))
     # Encoder-decoder attention.
     self.encdec_attention = self._cross_attention_cls(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
+        use_bias=self._use_bias,
         name="attention/encdec",
         **common_kwargs)
@@ -344,7 +399,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         rate=self.dropout_rate)
     self.encdec_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
+            name="attention/encdec_output_layer_norm",
+            axis=-1,
+            epsilon=self._norm_epsilon))
     # Feed-forward projection.
     self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
@@ -355,6 +412,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
         self.intermediate_activation)
+    self._intermediate_dropout_layer = tf.keras.layers.Dropout(
+        rate=self._intermediate_dropout)
     self.output_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
@@ -363,9 +422,49 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12)
+        name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
     super(TransformerDecoderLayer, self).build(input_shape)
+  def get_config(self):
+    config = {
+        "num_attention_heads":
+            self.num_attention_heads,
+        "intermediate_size":
+            self.intermediate_size,
+        "intermediate_activation":
+            self.intermediate_activation,
+        "dropout_rate":
+            self.dropout_rate,
+        "attention_dropout_rate":
+            self.attention_dropout_rate,
+        "multi_channel_cross_attention":
+            self.multi_channel_cross_attention,
+        "kernel_initializer":
+            tf.keras.initializers.serialize(self._kernel_initializer),
+        "bias_initializer":
+            tf.keras.initializers.serialize(self._bias_initializer),
+        "kernel_regularizer":
+            tf.keras.regularizers.serialize(self._kernel_regularizer),
+        "bias_regularizer":
+            tf.keras.regularizers.serialize(self._bias_regularizer),
+        "activity_regularizer":
+            tf.keras.regularizers.serialize(self._activity_regularizer),
+        "kernel_constraint":
+            tf.keras.constraints.serialize(self._kernel_constraint),
+        "bias_constraint":
+            tf.keras.constraints.serialize(self._bias_constraint),
+        "use_bias":
+            self._use_bias,
+        "norm_first":
+            self._norm_first,
+        "norm_epsilon":
+            self._norm_epsilon,
+        "intermediate_dropout":
+            self._intermediate_dropout
+    }
+    base_config = super(TransformerDecoderLayer, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
   def common_layers_with_encoder(self):
     """Gets layer objects that can make a Transformer encoder block."""
     return [
@@ -384,6 +483,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
           len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
+    source_tensor = input_tensor
+    if self._norm_first:
+      input_tensor = self.self_attention_layer_norm(input_tensor)
     self_attention_output, cache = self.self_attention(
         query=input_tensor,
         value=input_tensor,
@@ -391,8 +493,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         cache=cache,
         decode_loop_step=decode_loop_step)
     self_attention_output = self.self_attention_dropout(self_attention_output)
-    self_attention_output = self.self_attention_layer_norm(
-        input_tensor + self_attention_output)
+    if self._norm_first:
+      self_attention_output = source_tensor + self_attention_output
+    else:
+      self_attention_output = self.self_attention_layer_norm(
+          input_tensor + self_attention_output)
+    if self._norm_first:
+      source_self_attention_output = self_attention_output
+      self_attention_output = self.encdec_attention_layer_norm(
+          self_attention_output)
     cross_attn_inputs = dict(
         query=self_attention_output,
         value=memory,
@@ -402,13 +511,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
       cross_attn_inputs["context_attention_weights"] = inputs[-1]
     attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
-    attention_output = self.encdec_attention_layer_norm(self_attention_output +
-                                                        attention_output)
+    if self._norm_first:
+      attention_output = source_self_attention_output + attention_output
+    else:
+      attention_output = self.encdec_attention_layer_norm(
+          self_attention_output + attention_output)
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self.output_layer_norm(attention_output)
     intermediate_output = self.intermediate_dense(attention_output)
     intermediate_output = self.intermediate_activation_layer(
         intermediate_output)
+    intermediate_output = self._intermediate_dropout_layer(intermediate_output)
     layer_output = self.output_dense(intermediate_output)
     layer_output = self.output_dropout(layer_output)
-    layer_output = self.output_layer_norm(layer_output + attention_output)
+    if self._norm_first:
+      layer_output = source_attention_output + layer_output
+    else:
+      layer_output = self.output_layer_norm(layer_output + attention_output)
     return layer_output, cache
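
The `norm_first` additions to both `Transformer` and `TransformerDecoderLayer` switch the residual blocks from post-norm to pre-norm ordering. A schematic of the pattern (hypothetical helper functions, not the literal layer code); the decoder applies it three times, for self-attention, encoder-decoder attention, and the feed-forward projection:

def post_norm_block(x, sublayer, dropout, layer_norm):
  # norm_first=False: normalize the residual sum (previous behavior).
  return layer_norm(x + dropout(sublayer(x)))


def pre_norm_block(x, sublayer, dropout, layer_norm):
  # norm_first=True: normalize the input, leave the residual sum unnormalized.
  return x + dropout(sublayer(layer_norm(x)))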
@@ -152,10 +152,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     _ = new_layer([input_data, mask_data])
     new_layer.set_weights(test_layer.get_weights())
     new_output_tensor = new_layer([input_data, mask_data])
-    self.assertAllClose(new_output_tensor,
-                        output_tensor[:, 0:1, :],
-                        atol=5e-5,
-                        rtol=0.003)
+    self.assertAllClose(
+        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
   def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
@@ -218,6 +216,47 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     self.assertAllEqual([1, input_length, width], output_data.shape)
+@keras_parameterized.run_all_keras_modes
+class TransformerArgumentTest(keras_parameterized.TestCase):
+
+  def test_use_bias_norm_first(self):
+    num_attention_heads = 2
+    hidden_size = 16
+    encoder_block = transformer.Transformer(
+        num_attention_heads=num_attention_heads,
+        intermediate_size=32,
+        intermediate_activation='relu',
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.1)
+    # Forward path.
+    dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
+    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
+    inputs = [dummy_tensor, dummy_mask]
+    output = encoder_block(inputs)
+    self.assertEqual(output.shape, (2, 4, hidden_size))
+
+  def test_get_config(self):
+    num_attention_heads = 2
+    encoder_block = transformer.Transformer(
+        num_attention_heads=num_attention_heads,
+        intermediate_size=32,
+        intermediate_activation='relu',
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.1)
+    encoder_block_config = encoder_block.get_config()
+    new_encoder_block = transformer.Transformer.from_config(
+        encoder_block_config)
+    self.assertEqual(encoder_block_config, new_encoder_block.get_config())
 def _create_cache(batch_size, init_decode_length, num_heads, head_size):
   return {
       'key':
@@ -251,6 +290,43 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
     self.assertEqual(output.shape, (2, 4, hidden_size))
     self.assertEqual(cache['value'].shape, (2, 4, 2, 8))
+  def test_use_bias_norm_first(self):
+    num_attention_heads = 2
+    hidden_size = 16
+    decoder_block = transformer.TransformerDecoderLayer(
+        num_attention_heads=num_attention_heads,
+        intermediate_size=32,
+        intermediate_activation='relu',
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.1)
+    # Forward path.
+    dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
+    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
+    inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
+    output, _ = decoder_block(inputs)
+    self.assertEqual(output.shape, (2, 4, hidden_size))
+
+  def test_get_config(self):
+    num_attention_heads = 2
+    decoder_block = transformer.TransformerDecoderLayer(
+        num_attention_heads=num_attention_heads,
+        intermediate_size=32,
+        intermediate_activation='relu',
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1,
+        use_bias=False,
+        norm_first=True,
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.1)
+    decoder_block_config = decoder_block.get_config()
+    new_decoder_block = transformer.TransformerDecoderLayer.from_config(
+        decoder_block_config)
+    self.assertEqual(decoder_block_config, new_decoder_block.get_config())
 if __name__ == '__main__':
   tf.test.main()
@@ -10,8 +10,8 @@ model containing a single classification head using the Classification network.
 It can be used as a regression model as well.
 * [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
-classification model containing a single classification head using the
-TokenClassification network.
+classification model containing a single classification head over the sequence
+output embeddings.
 * [`BertSpanLabeler`](bert_span_labeler.py) implementats a simple single-span
 start-end predictor (that is, a model that predicts two values: a start token
...
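
For orientation, a rough usage sketch of the token classifier described above; the exact constructor arguments are assumed from the surrounding model-garden API and may differ slightly.

from official.nlp.modeling import models
from official.nlp.modeling import networks

network = networks.TransformerEncoder(vocab_size=30522, num_layers=2)
# One label per input token, predicted from the sequence output embeddings.
tagger = models.BertTokenClassifier(network=network, num_classes=9)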
@@ -14,7 +14,7 @@
 # ==============================================================================
 """Models package definition."""
 from official.nlp.modeling.models.bert_classifier import BertClassifier
-from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
+from official.nlp.modeling.models.bert_pretrainer import *
 from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
 from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
 from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
@@ -12,12 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Trainer network for BERT-style models."""
+"""BERT cls-token classifier."""
 # pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
 import tensorflow as tf
...
@@ -38,7 +38,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size, num_layers=2)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
@@ -62,7 +62,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+        vocab_size=100, num_layers=2)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
@@ -83,7 +83,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+        vocab_size=100, num_layers=2)
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
...
@@ -12,12 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Trainer network for BERT-style models."""
+"""BERT Pre-training model."""
 # pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
 import copy
 from typing import List, Optional
@@ -98,7 +94,8 @@ class BertPretrainer(tf.keras.Model):
     if isinstance(cls_output, list):
       cls_output = cls_output[-1]
     sequence_output_length = sequence_output.shape.as_list()[1]
-    if sequence_output_length < num_token_predictions:
+    if sequence_output_length is not None and (sequence_output_length <
+                                               num_token_predictions):
       raise ValueError(
           "The passed network's output length is %s, which is less than the "
          'requested num_token_predictions %s.' %
...
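
The added `is not None` guard matters because `shape.as_list()[1]` is `None` when the encoder is built with a dynamic sequence length, and `None < num_token_predictions` raises a TypeError under Python 3. A tiny illustration:

import tensorflow as tf

static = tf.keras.Input(shape=(128, 768))
dynamic = tf.keras.Input(shape=(None, 768))
print(static.shape.as_list()[1])   # 128
print(dynamic.shape.as_list()[1])  # None, so the length check must be skipped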