Unverified Commit aef943ed authored by SunJong Park, committed by GitHub

Merge branch 'tensorflow:master' into master

parents 67ad909d 930abe21
......@@ -35,22 +35,16 @@ To install the current release of tensorflow-models, please follow any one of th
**tf-models-official** is the stable Model Garden package. Please check out the [releases](https://github.com/tensorflow/models/releases) to see which modules are available.
-pip will install all models and dependencies automatically.
+pip3 will install all models and dependencies automatically.
```shell
pip3 install tf-models-official
```
If you are using nlp packages, please also install **tensorflow-text**:
```shell
pip3 install tensorflow-text
```
Please check out our [example](https://github.com/tensorflow/text/blob/master/docs/tutorials/fine_tune_bert.ipynb)
to learn how to use a PIP package.
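As a quick sanity check after installing (a minimal sketch; it assumes your release exposes the `official` namespace and that `tensorflow-text` is installed as above):
```python
# Minimal post-install check; module paths may vary by release.
import tensorflow as tf
import tensorflow_text  # registers the TF-Text ops
from official.nlp import modeling  # Model Garden NLP modeling library

print(tf.__version__)
print(modeling.layers.TransformerEncoderBlock.__name__)
```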
-Note that **tf-models-official** may not include the latest changes in this
+Note that **tf-models-official** may not include the latest changes in the master branch of this
github repo. To include the latest changes, you may install **tf-models-nightly**,
which is the nightly Model Garden package created daily automatically.
......@@ -58,11 +52,6 @@ which is the nightly Model Garden package created daily automatically.
pip3 install tf-models-nightly
```
If you are using `nlp` packages, please also install tensorflow-text-nightly
```shell
pip3 install tensorflow-text-nightly
```
</details>
......
......@@ -66,12 +66,32 @@ In the near future, we will add:
### [Natural Language Processing](nlp/README.md)
+#### Pre-trained Language Model
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [ALBERT](nlp/MODEL_GARDEN.md#available-model-configs) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
+| [BERT](nlp/MODEL_GARDEN.md#available-model-configs) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
+| [ELECTRA](nlp/tasks/electra_task.py) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555) |
+#### Neural Machine Translation
| Model | Reference (Paper) |
|-------|-------------------|
-| [ALBERT (A Lite BERT)](nlp/MODEL_GARDEN.md#available-model-configs) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
-| [BERT (Bidirectional Encoder Representations from Transformers)](nlp/MODEL_GARDEN.md#available-model-configs) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
-| [NHNet (News Headline generation model)](projects/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
| [Transformer](nlp/MODEL_GARDEN.md#available-model-configs) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
+#### Natural Language Generation
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [NHNet (News Headline generation model)](projects/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
+#### Knowledge Distillation
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [MobileBERT](projects/mobilebert) | [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) |
### Recommendation
......
......@@ -15,7 +15,7 @@
"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
-from typing import Any, Mapping, Optional, Tuple
+from typing import Any, Mapping, Optional, Tuple, List
# Import libraries
......@@ -40,6 +40,8 @@ def run_experiment(
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
......@@ -55,6 +57,8 @@ def run_experiment(
run_post_eval: Whether to run a post evaluation once after training; metric
logs are returned.
save_summary: Whether to save train and validation summary.
train_actions: Optional list of Orbit train actions.
eval_actions: Optional list of Orbit eval actions.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
controller_cls: The controller class to manage the train and eval process.
......@@ -90,6 +94,13 @@ def run_experiment(
else:
checkpoint_manager = None
train_actions = [] if not train_actions else train_actions
train_actions += actions.get_train_actions(
params, trainer, model_dir, checkpoint_manager=checkpoint_manager)
eval_actions = [] if not eval_actions else eval_actions
eval_actions += actions.get_eval_actions(params, trainer, model_dir)
controller = controller_cls(
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
......@@ -103,9 +114,8 @@ def run_experiment(
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
(save_summary) else None,
-train_actions=actions.get_train_actions(
-params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
-eval_actions=actions.get_eval_actions(params, trainer, model_dir))
+train_actions=train_actions,
+eval_actions=eval_actions)
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
......
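For illustration, a hedged sketch of how the new `train_actions`/`eval_actions` arguments might be used; `strategy`, `task`, `params`, and `model_dir` are assumed to exist, and an Orbit action is simply a callable invoked with the loop's output:
```python
# Hypothetical custom action passed through the new arguments.
class PrintLossAction:
  """Minimal Orbit-style action: called with the train/eval loop output."""

  def __call__(self, output):
    if "training_loss" in output:  # key name depends on the trainer's metrics
      print("training_loss:", float(output["training_loss"]))

model, eval_logs = train_lib.run_experiment(
    distribution_strategy=strategy,
    task=task,
    mode="train_and_eval",
    params=params,
    model_dir=model_dir,
    train_actions=[PrintLossAction()],  # appended to the default actions
    eval_actions=[PrintLossAction()])
```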
......@@ -43,8 +43,8 @@ def export_tfhub(model_path, hub_destination, model_name):
image_input = tf.keras.layers.Input(
shape=(None, None, 3), name="image_input", dtype=tf.float32)
x = image_input * 255.0
-ouputs = efficientnet_model.efficientnet(x, config)
-hub_model = tf.keras.Model(image_input, ouputs)
+outputs = efficientnet_model.efficientnet(x, config)
+hub_model = tf.keras.Model(image_input, outputs)
ckpt = tf.train.Checkpoint(model=hub_model)
ckpt.restore(model_path).assert_existing_objects_matched()
hub_model.save(
......
......@@ -17,6 +17,8 @@ import math
import tensorflow as tf
from official.modeling import tf_utils
class Attention(tf.keras.layers.Layer):
"""Multi-headed attention layer."""
......@@ -53,19 +55,19 @@ class Attention(tf.keras.layers.Layer):
self.query_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
-kernel_initializer=attention_initializer,
+kernel_initializer=tf_utils.clone_initializer(attention_initializer),
bias_axes=None,
name="query")
self.key_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
-kernel_initializer=attention_initializer,
+kernel_initializer=tf_utils.clone_initializer(attention_initializer),
bias_axes=None,
name="key")
self.value_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
-kernel_initializer=attention_initializer,
+kernel_initializer=tf_utils.clone_initializer(attention_initializer),
bias_axes=None,
name="value")
......
......@@ -188,7 +188,7 @@ def download_and_extract(path, url, input_filename, target_filename):
Full paths to extracted input and target files.
Raises:
-OSError: if the the download/extraction fails.
+OSError: if the download/extraction fails.
"""
# Check if extracted files already exist in path
input_file = find_file(path, input_filename)
......
......@@ -41,15 +41,15 @@ _PARAM_RE = re.compile(
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
-# Yaml loader with an implicit resolver to parse float decimal and exponential
+# Yaml LOADER with an implicit resolver to parse float decimal and exponential
# format. The regular expression parses the following cases:
# 1- Decimal number with an optional exponential term.
# 2- Integer number with an exponential term.
# 3- Decimal number with an optional exponential term.
# 4- Decimal number.
-LOADER = yaml.SafeLoader
-LOADER.add_implicit_resolver(
+_LOADER = yaml.SafeLoader
+_LOADER.add_implicit_resolver(
'tag:yaml.org,2002:float',
re.compile(r'''
^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
......@@ -332,7 +332,7 @@ class ParamsDict(object):
def read_yaml_to_params_dict(file_path: str):
"""Reads a YAML file to a ParamsDict."""
with tf.io.gfile.GFile(file_path, 'r') as f:
-params_dict = yaml.load(f, Loader=LOADER)
+params_dict = yaml.load(f, Loader=_LOADER)
return ParamsDict(params_dict)
......@@ -453,7 +453,7 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
except ValueError:
pass
-params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=LOADER)
+params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=_LOADER)
if isinstance(params_dict, dict):
params.override(params_dict, is_strict)
else:
......
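For context, a small illustration (not part of the diff) of why the implicit resolver exists: PyYAML's stock `SafeLoader` follows YAML 1.1 and parses exponent-only numbers such as `1e-4` as strings, while the patched `_LOADER` resolves them as floats:
```python
import yaml

print(yaml.safe_load("lr: 1e-4"))             # {'lr': '1e-4'} -- a string
print(yaml.load("lr: 1e-4", Loader=_LOADER))  # {'lr': 0.0001} -- a float
```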
......@@ -23,7 +23,7 @@ import tensorflow as tf
def _make_offset_wrapper(new_class_name: str, base_lr_class):
"""Generates a offset wrapper of learning rate schedule.
It will returns a subclass of the the `base_lr_class`, the subclass takes an
It will returns a subclass of the `base_lr_class`, the subclass takes an
`offset` argument in the constructor. When the new class instance is called,
the behavior is:
new_class_object(step) = base_lr_class_object(step - offset)
......
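A minimal sketch of the offset behavior that docstring describes (a hypothetical subclass; the real wrapper is generated dynamically from `base_lr_class`):
```python
import tensorflow as tf

class OffsetCosineDecay(tf.keras.optimizers.schedules.CosineDecay):
  """Cosine decay whose schedule starts counting after `offset` steps."""

  def __init__(self, *args, offset=0, **kwargs):
    super().__init__(*args, **kwargs)
    self._offset = offset

  def __call__(self, step):
    # Matches the documented behavior: new(step) = base(step - offset).
    return super().__call__(step - self._offset)

schedule = OffsetCosineDecay(
    initial_learning_rate=1e-3, decay_steps=1000, offset=100)
```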
......@@ -272,3 +272,14 @@ def cross_replica_concat(value, axis, name="cross_replica_concat"):
if value.shape.as_list()[0] is None:
raise RuntimeError(f"{value} has unknown batch.")
return context.all_gather(value, axis=axis)
def clone_initializer(initializer):
# Keras initializers are stateless, which means reusing the same initializer
# instance produces identical initial values whenever the shapes are the same.
if isinstance(initializer, tf.keras.initializers.Initializer):
return initializer.__class__.from_config(initializer.get_config())
# When the input is a string/dict or another serialized config, the caller
# will create a new Keras initializer instance from it, so there is nothing
# to do here.
return initializer
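To see why the clone matters (a sketch; exact behavior depends on the TF version's stateless-initializer semantics):
```python
import tensorflow as tf

init = tf.keras.initializers.GlorotUniform()
a = init(shape=(2, 2))
b = init(shape=(2, 2))
# With a stateless initializer, `a` and `b` are identical, so layers sharing
# this one instance would start from the same weights.

fresh = clone_initializer(init)
c = fresh(shape=(2, 2))  # a new instance draws its own values
```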
......@@ -36,7 +36,7 @@ from sentencepiece import SentencePieceTrainer
FLAGS = flags.FLAGS
flags.DEFINE_string("output_model_path", None,
"Path to save the the sentencepiece model.")
"Path to save the sentencepiece model.")
flags.mark_flag_as_required("output_model_path")
flags.DEFINE_string("tfds_dir", None, "Directory of the tfds.")
......
......@@ -105,7 +105,7 @@ pip3 install --user -r official/requirements.txt
<details>
-This example fine-tunes BERT-base from TF-Hub on the the Multi-Genre Natural
+This example fine-tunes BERT-base from TF-Hub on the Multi-Genre Natural
Language Inference (MultiNLI) corpus using TPUs.
Firstly, you can prepare the fine-tuning data using
......
......@@ -13,7 +13,7 @@ assemble new `tf.keras` layers or models.
["Big Bird: Transformers for Longer Sequences"](https://arxiv.org/abs/2007.14062).
* [CachedAttention](attention.py) implements an attention layer with cache
-used for auto-agressive decoding.
+used for auto-regressive decoding.
* [KernelAttention](kernel_attention.py) implements a group of attention
mechanisms that express the self-attention as a linear dot-product of
......
......@@ -18,6 +18,8 @@ from typing import Optional
import tensorflow as tf
from official.modeling import tf_utils
class BlockDiagFeedforward(tf.keras.layers.Layer):
"""Block diagonal feedforward layer.
......@@ -80,8 +82,6 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
hidden_size = input_shape.as_list()[-1]
common_kwargs = dict(
-kernel_initializer=self._kernel_initializer,
-bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -94,6 +94,8 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
self._intermediate_size // self._num_blocks),
bias_axes="de",
name="intermediate",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
......@@ -110,6 +112,8 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
hidden_size // self._num_blocks),
bias_axes="do",
name="output",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
if self._apply_mixing:
......@@ -118,6 +122,9 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
output_shape=(None, self._num_blocks,
hidden_size // self._num_blocks),
name="output_mixing",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_reshape = tf.keras.layers.Reshape((-1, hidden_size))
......
......@@ -57,12 +57,14 @@ class ClassificationHead(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(
units=self.inner_dim,
activation=self.activation,
-kernel_initializer=self.initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name="pooler_dense")
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.out_proj = tf.keras.layers.Dense(
-units=num_classes, kernel_initializer=self.initializer, name="logits")
+units=num_classes,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
+name="logits")
def call(self, features: tf.Tensor, only_project: bool = False):
"""Implements call().
......@@ -146,14 +148,15 @@ class MultiClsHeads(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(
units=inner_dim,
activation=self.activation,
-kernel_initializer=self.initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name="pooler_dense")
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.out_projs = []
for name, num_classes in cls_list:
self.out_projs.append(
tf.keras.layers.Dense(
-units=num_classes, kernel_initializer=self.initializer,
+units=num_classes,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=name))
def call(self, features: tf.Tensor, only_project: bool = False):
......@@ -277,7 +280,7 @@ class GaussianProcessClassificationHead(ClassificationHead):
if use_gp_layer:
self.out_proj = gaussian_process.RandomFeatureGaussianProcess(
self.num_classes,
-kernel_initializer=self.initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name="logits",
**self.gp_layer_kwargs)
......@@ -361,3 +364,97 @@ def extract_spec_norm_kwargs(kwargs):
return dict(
iteration=kwargs.pop("iteration", 1),
norm_multiplier=kwargs.pop("norm_multiplier", .99))
class PerQueryDenseHead(tf.keras.layers.Layer):
"""Pooling head used for EncT5 style models.
This module projects each query with a different projection.
For an input of shape [bs, num_queries, hidden_size], it projects each query
to `features` dimensions, ending up with shape [bs, num_queries, features].
For example, for classification with a few classes, one may use num_queries
as 1 and features as the number of classes. For multilabel classification,
one may use num_queries as the number of classes and features as 2, so each
query represents a binary classification of one label.
"""
def __init__(self,
num_queries: int,
features: int,
use_bias: bool = False,
kernel_initializer: str = "glorot_uniform",
**kwargs):
"""Initializes the `PerQueryDenseHead`.
Args:
num_queries: number of queries (the learnable embeddings in the input
sequences) from the decoder.
features: int, the number of output features. Each query will be
projected to this size with a different projection.
use_bias: whether to add a bias to the output.
kernel_initializer: Initializer for dense layer kernels.
**kwargs: Keyword arguments.
"""
super().__init__(**kwargs)
self.num_queries = num_queries
self.features = features
self.use_bias = use_bias
self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
def build(self, input_shape):
input_shape = tf.TensorShape(input_shape)
# Hidden size.
last_dim = tf.compat.dimension_value(input_shape[-1])
self.hidden_size = last_dim
self.kernel = self.add_weight(
"kernel",
shape=[self.num_queries, last_dim, self.features],
initializer=self.kernel_initializer,
dtype=self.dtype,
trainable=True)
if self.use_bias:
self.bias = self.add_weight(
"bias",
shape=[
self.num_queries,
self.features,
],
dtype=self.dtype,
trainable=True)
else:
self.bias = None
def call(self, inputs: tf.Tensor) -> tf.Tensor:
"""Implements call().
Args:
inputs: a rank-3 Tensor of shape [bs, num_queries, hidden_size].
Returns:
A Tensor of shape [bs, num_queries, features].
"""
outputs = tf.einsum("bqh,qhf->bqf", inputs, self.kernel)
if self.use_bias:
outputs += self.bias
return outputs
def get_config(self):
config = {
"num_queries": self.num_queries,
"features": self.features,
"use_bias": self.use_bias,
"kernel_initializer": tf.keras.initializers.serialize(self.kernel_initializer),
}
config.update(super(PerQueryDenseHead, self).get_config())
return config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
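A small usage sketch of the head defined above (hypothetical shapes; a multilabel setup with one binary decision per query):
```python
import tensorflow as tf

head = PerQueryDenseHead(num_queries=4, features=2, use_bias=True)
decoder_out = tf.random.normal([8, 4, 128])  # [batch, num_queries, hidden]
logits = head(decoder_out)                   # shape [8, 4, 2]
```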
......@@ -199,5 +199,29 @@ class GaussianProcessClassificationHead(tf.test.TestCase,
self.assertEqual(layer_config["norm_multiplier"], 1.)
self.assertEqual(layer_config["num_inducing"], 512)
class PerQueryDenseHeadTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("single_query", 1, 3, False),
("multi_queries", 10, 2, False),
("with_bias", 10, 2, True))
def test_layer_invocation(self, num_queries, features, use_bias):
batch_size = 5
hidden_size = 10
layer = cls_head.PerQueryDenseHead(
num_queries=num_queries, features=features, use_bias=use_bias)
inputs = tf.zeros(
shape=(batch_size, num_queries, hidden_size), dtype=tf.float32)
outputs = layer(inputs)
self.assertEqual(outputs.shape, [batch_size, num_queries, features])
def test_layer_serialization(self):
layer = cls_head.PerQueryDenseHead(
num_queries=10, features=2, use_bias=True)
new_layer = cls_head.PerQueryDenseHead.from_config(layer.get_config())
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(layer.get_config(), new_layer.get_config())
if __name__ == "__main__":
tf.test.main()
......@@ -18,6 +18,8 @@
import gin
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
......@@ -95,8 +97,6 @@ class GatedFeedforward(tf.keras.layers.Layer):
hidden_size = input_shape.as_list()[-1]
common_kwargs = dict(
-kernel_initializer=self._kernel_initializer,
-bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -121,6 +121,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate_%d" % i,
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
self._intermediate_activation_layers.append(
tf.keras.layers.Activation(
......@@ -132,6 +136,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="gate_%d" % i,
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
self._output_dense.append(
tf.keras.layers.experimental.EinsumDense(
......@@ -139,6 +147,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
output_shape=(None, hidden_size),
bias_axes="d",
name="output_%d" % i,
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
self._output_dropout.append(tf.keras.layers.Dropout(rate=self._dropout))
# Use float32 in layernorm for numeric stability.
......
......@@ -226,7 +226,7 @@ class RandomFeatureGaussianProcess(tf.keras.layers.Layer):
"""Resets covariance matrix of the GP layer.
This function is useful for resetting the model's covariance matrix at the
-begining of a new epoch.
+beginning of a new epoch.
"""
self._gp_cov_layer.reset_precision_matrix()
......@@ -380,7 +380,7 @@ class LaplaceRandomFeatureCovariance(tf.keras.layers.Layer):
"""Resets precision matrix to its initial value.
This function is useful for resetting the model's covariance matrix at the
-begining of a new epoch.
+beginning of a new epoch.
"""
precision_matrix_reset_op = self.precision_matrix.assign(
self.initial_precision_matrix)
......
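A hedged sketch of the epoch-reset pattern these docstrings describe (hypothetical `model`, `train_ds`, and `train_step`; it assumes the model's output layer is the `RandomFeatureGaussianProcess` above):
```python
# Recompute the Laplace covariance from scratch each epoch.
for epoch in range(num_epochs):
  model.output_layer.reset_covariance_matrix()
  for x_batch, y_batch in train_ds:
    train_step(x_batch, y_batch)
```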
......@@ -15,6 +15,8 @@
"""MobileBERT embedding and transformer layers."""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import on_device_embedding
from official.nlp.modeling.layers import position_embedding
......@@ -109,21 +111,21 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
self.word_embedding = on_device_embedding.OnDeviceEmbedding(
self.word_vocab_size,
self.word_embed_size,
-initializer=initializer,
+initializer=tf_utils.clone_initializer(self.initializer),
name='word_embedding')
self.type_embedding = on_device_embedding.OnDeviceEmbedding(
self.type_vocab_size,
self.output_embed_size,
-initializer=initializer,
+initializer=tf_utils.clone_initializer(self.initializer),
name='type_embedding')
self.pos_embedding = position_embedding.PositionEmbedding(
max_length=max_sequence_length,
-initializer=initializer,
+initializer=tf_utils.clone_initializer(self.initializer),
name='position_embedding')
self.word_embedding_proj = tf.keras.layers.experimental.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.output_embed_size],
-kernel_initializer=initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
bias_axes='d',
name='embedding_projection')
self.layer_norm = _get_norm_layer(normalization_type, 'embedding_norm')
......@@ -246,7 +248,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
-kernel_initializer=initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='bottleneck_input/dense')
layer_norm = _get_norm_layer(self.normalization_type,
name='bottleneck_input/norm')
......@@ -258,7 +260,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
-kernel_initializer=initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='kq_shared_bottleneck/dense')
layer_norm = _get_norm_layer(self.normalization_type,
name='kq_shared_bottleneck/norm')
......@@ -272,7 +274,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
value_dim=attention_head_size,
dropout=self.attention_probs_dropout_prob,
output_shape=self.intra_bottleneck_size,
-kernel_initializer=initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='attention')
layer_norm = _get_norm_layer(self.normalization_type,
name='attention/norm')
......@@ -289,14 +291,14 @@ class MobileBertTransformer(tf.keras.layers.Layer):
activation=self.intermediate_act_fn,
output_shape=[None, self.intermediate_size],
bias_axes='d',
-kernel_initializer=initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=layer_name)
layer_name = layer_prefix + '/output_dense'
output_layer = tf.keras.layers.experimental.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
-kernel_initializer=initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=layer_name)
layer_name = layer_prefix + '/norm'
layer_norm = _get_norm_layer(self.normalization_type,
......@@ -311,7 +313,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
output_shape=[None, self.hidden_size],
activation=None,
bias_axes='d',
-kernel_initializer=initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='bottleneck_output/dense')
dropout_layer = tf.keras.layers.Dropout(
self.hidden_dropout_prob,
......@@ -474,14 +476,14 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(
hidden_size,
activation=self.activation,
-kernel_initializer=self.initializer,
+kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='transform/dense')
if hidden_size > embedding_width:
self.extra_output_weights = self.add_weight(
'extra_output_weights',
shape=(self._vocab_size, hidden_size - embedding_width),
-initializer=self.initializer,
+initializer=tf_utils.clone_initializer(self.initializer),
trainable=True)
elif hidden_size == embedding_width:
self.extra_output_weights = None
......
......@@ -18,6 +18,7 @@
import math
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import masked_softmax
......@@ -60,8 +61,6 @@ class VotingAttention(tf.keras.layers.Layer):
def build(self, unused_input_shapes):
common_kwargs = dict(
-kernel_initializer=self._kernel_initializer,
-bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -72,12 +71,16 @@ class VotingAttention(tf.keras.layers.Layer):
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
name="query",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._key_dense = tf.keras.layers.experimental.EinsumDense(
"BAE,ENH->BANH",
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
name="key",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
super(VotingAttention, self).build(unused_input_shapes)
......
......@@ -22,6 +22,8 @@ import string
import numpy as np
import tensorflow as tf
from official.modeling import tf_utils
_CHR_IDX = string.ascii_lowercase
......@@ -347,8 +349,6 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._key_shape = tf.TensorShape(key)
common_kwargs = dict(
-kernel_initializer=self._kernel_initializer,
-bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -368,6 +368,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._key_shape.rank - 1, bound_dims=1, output_dims=2)
......@@ -377,6 +381,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._value_shape.rank - 1, bound_dims=1, output_dims=2)
......@@ -389,6 +397,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
output_rank - 1, [self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_reuse",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
if self._reuse_heads < self._num_heads:
self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
......@@ -397,6 +409,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._num_heads - self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_new",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
# Builds the attention computations for multi-head dot product attention.
......@@ -439,13 +455,17 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if (use_bias and self._use_bias) else None,
name=name,
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs)
def _build_attention(self, rank):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
-costomize attention computation to replace the default dot-product
+customize attention computation to replace the default dot-product
attention.
Args:
......