Commit 8f815365 authored by Le Hou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 433920851
parent b2c92a84
...@@ -42,6 +42,7 @@ from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAtt
from official.nlp.modeling.layers.reuse_attention import ReuseMultiHeadAttention
from official.nlp.modeling.layers.reuse_transformer import ReuseTransformer
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.routing import *
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.spectral_normalization import *
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
...
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for Mixture of Experts (MoE) routing.
For MoE routing, we need to separate a set of tokens to sets of tokens.
Later on, different sets of tokens can potentially go to different experts.
"""
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package="Text")
class TokenImportanceWithMovingAvg(tf.keras.layers.Layer):
"""Routing based on per-token importance value."""
def __init__(self,
vocab_size,
init_importance,
moving_average_beta=0.995,
**kwargs):
self._vocab_size = vocab_size
self._init_importance = init_importance
self._moving_average_beta = moving_average_beta
super(TokenImportanceWithMovingAvg, self).__init__(**kwargs)
def build(self, input_shape):
self._importance_embedding = self.add_weight(
name="importance_embed",
shape=(self._vocab_size,),
initializer=tf.keras.initializers.Constant(self._init_importance),
trainable=False)
def get_config(self):
config = {
"vocab_size":
self._vocab_size,
"init_importance":
self._init_importance,
"moving_average_beta":
self._moving_average_beta,
}
base_config = super(TokenImportanceWithMovingAvg, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def update_token_importance(self, token_ids, importance):
token_ids = tf.reshape(token_ids, shape=[-1])
importance = tf.reshape(importance, shape=[-1])
beta = self._moving_average_beta
old_importance = tf.gather(self._importance_embedding, token_ids)
self._importance_embedding.assign(tf.tensor_scatter_nd_update(
self._importance_embedding,
tf.expand_dims(token_ids, axis=1),
old_importance * beta + tf.cast(importance * (1.0 - beta),
dtype=tf.float32)))
def call(self, inputs):
return tf.gather(self._importance_embedding, inputs)
@tf.keras.utils.register_keras_serializable(package="Text")
class SelectTopK(tf.keras.layers.Layer):
"""Select top-k + random-k tokens according to importance."""
def __init__(self,
top_k=None,
random_k=None,
**kwargs):
self._top_k = top_k
self._random_k = random_k
super(SelectTopK, self).__init__(**kwargs)
def get_config(self):
config = {
"top_k":
self._top_k,
"random_k":
self._random_k,
}
base_config = super(SelectTopK, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
if self._random_k is None:
# Pure top-k, no randomness.
pos = tf.argsort(inputs, direction="DESCENDING")
selected = tf.slice(pos, [0, 0], [-1, self._top_k])
not_selected = tf.slice(pos, [0, self._top_k], [-1, -1])
elif self._top_k is None:
# Pure randomness, no top-k.
pos = tf.argsort(tf.random.uniform(shape=tf.shape(inputs)),
direction="DESCENDING")
selected = tf.slice(pos, [0, 0], [-1, self._random_k])
not_selected = tf.slice(pos, [0, self._random_k], [-1, -1])
else:
# Top-k plus randomness.
pos = tf.argsort(inputs, direction="DESCENDING")
selected_top_k = tf.slice(pos, [0, 0], [-1, self._top_k])
pos_left = tf.slice(pos, [0, self._top_k], [-1, -1])
# Randomly shuffle pos_left
sort_index = tf.argsort(
tf.random.uniform(shape=tf.shape(pos_left)),
direction="DESCENDING")
pos_left = tf.gather(pos_left, sort_index, batch_dims=1, axis=1)
selected_rand = tf.slice(pos_left, [0, 0], [-1, self._random_k])
not_selected = tf.slice(pos_left, [0, self._random_k], [-1, -1])
selected = tf.concat([selected_top_k, selected_rand], axis=1)
# Return the indices of selected and not-selected tokens.
return selected, not_selected
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for routing."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import routing
class TokenImportanceTest(tf.test.TestCase, parameterized.TestCase):
def test_token_importance(self):
token_importance_embed = routing.TokenImportanceWithMovingAvg(
vocab_size=4,
init_importance=10.0,
moving_average_beta=0.995)
importance = token_importance_embed(np.array([[0, 1], [2, 3]]))
self.assertAllClose(importance, np.array([[10.0, 10.0], [10.0, 10.0]]))
token_importance_embed.update_token_importance(
token_ids=np.array([[0, 1]]),
importance=np.array([[0.0, 0.0]]))
importance = token_importance_embed(np.array([[0, 1], [2, 3]]))
self.assertAllClose(importance, np.array([[9.95, 9.95], [10.0, 10.0]]))
class TopKSelectionTest(tf.test.TestCase, parameterized.TestCase):
def test_top_k_selection(self):
token_selection = routing.SelectTopK(top_k=2)
selected, _ = token_selection(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
self.assertAllClose(selected, np.array([[3, 2], [0, 1]]))
def test_random_k_selection(self):
token_selection = routing.SelectTopK(random_k=2)
selected, _ = token_selection(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
self.assertAllClose(selected.shape, (2, 2))
def test_top_k_random_k(self):
token_selection = routing.SelectTopK(top_k=1, random_k=1)
selected, _ = token_selection(np.array([[0, 1, 2, 3], [4, 3, 2, 1]]))
self.assertAllClose(selected.shape, (2, 2))
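  # The following test is an illustrative sketch that is not part of the
  # original change: it shows how SelectTopK composes with tf.gather to route
  # hidden states, and how the original token order can be restored afterwards,
  # mirroring what the token-dropping encoder does. Shapes and values are
  # arbitrary assumptions chosen for readability.
  def test_select_gather_and_restore_order(self):
    importance = np.array([[0.0, 3.0, 1.0, 2.0]])
    # One sequence of four tokens, each with a 2-dim hidden state.
    hidden = tf.reshape(tf.range(8, dtype=tf.float32), [1, 4, 2])
    selected, not_selected = routing.SelectTopK(top_k=2)(importance)
    self.assertAllEqual(selected, np.array([[1, 3]]))
    # Keep only the hidden states of the selected (important) tokens.
    hidden_selected = tf.gather(hidden, selected, batch_dims=1, axis=1)
    self.assertAllClose(hidden_selected.shape, (1, 2, 2))
    # Merge selected and not-selected tokens back and undo the permutation.
    hidden_not_selected = tf.gather(hidden, not_selected, batch_dims=1, axis=1)
    merged = tf.concat([hidden_selected, hidden_not_selected], axis=1)
    reverse_indices = tf.argsort(tf.concat([selected, not_selected], axis=1))
    restored = tf.gather(merged, reverse_indices, batch_dims=1, axis=1)
    self.assertAllClose(restored, hidden)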
if __name__ == "__main__":
tf.test.main()
task:
model:
encoder:
type: any
any:
token_allow_list: !!python/tuple
- 100 # [UNK]
- 101 # [CLS]
- 102 # [SEP]
- 103 # [MASK]
token_deny_list: !!python/tuple
- 0 # [PAD]
attention_dropout_rate: 0.1
dropout_rate: 0.1
hidden_activation: gelu
hidden_size: 768
initializer_range: 0.02
intermediate_size: 3072
max_position_embeddings: 512
num_attention_heads: 12
num_layers: 12
type_vocab_size: 2
vocab_size: 30522
token_loss_init_value: 10.0
token_loss_beta: 0.995
token_keep_k: 256
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer-based BERT encoder network."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Callable, Optional, Union, Tuple
from absl import logging
import tensorflow as tf
from official.nlp.modeling import layers
_Initializer = Union[str, tf.keras.initializers.Initializer]
_Activation = Union[str, Callable[..., Any]]
_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
class TokenDropBertEncoder(tf.keras.layers.Layer):
"""Bi-directional Transformer-based encoder network with token dropping.
During pretraining, we drop unimportant tokens starting from an intermediate
layer in the model, to make the model focus on important tokens more
efficiently with its limited computational resources. The dropped tokens are
later picked up by the last layer of the model, so that the model still
produces full-length sequences. This approach reduces the pretraining cost of
BERT by 25% while achieving better overall fine-tuning performance on standard
downstream tasks.
Args:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
token_loss_init_value: The default loss value of a token, used when the token
has never been masked and predicted.
token_loss_beta: The running-average factor used to compute the average loss
value of a token.
token_keep_k: The number of tokens to keep in the intermediate layers. The
rest are dropped in those layers.
token_allow_list: The list of token-ids that should never be dropped. In the
BERT English vocab, token-ids 1 to 998 are special tokens such as [CLS]
and [SEP]. By default, token_allow_list contains the ids of [UNK], [CLS],
[SEP] and [MASK].
token_deny_list: The list of token-ids that should always be dropped. In the
BERT English vocab, token-id 0 is [PAD]. By default, token_deny_list
contains only [PAD].
initializer: The initializer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to the attention and intermediate
dense layers. If set to False, the output of the attention and intermediate
dense layers is normalized instead.
with_dense_inputs: Whether to accept dense embeddings as the input.
"""
def __init__(
self,
vocab_size: int,
hidden_size: int = 768,
num_layers: int = 12,
num_attention_heads: int = 12,
max_sequence_length: int = 512,
type_vocab_size: int = 16,
inner_dim: int = 3072,
inner_activation: _Activation = _approx_gelu,
output_dropout: float = 0.1,
attention_dropout: float = 0.1,
token_loss_init_value: float = 10.0,
token_loss_beta: float = 0.995,
token_keep_k: int = 256,
token_allow_list: Tuple[int, ...] = (100, 101, 102, 103),
token_deny_list: Tuple[int, ...] = (0,),
initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
stddev=0.02),
output_range: Optional[int] = None,
embedding_width: Optional[int] = None,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
with_dense_inputs: bool = False,
**kwargs):
# Pops kwargs that are used in V1 implementation.
if 'dict_outputs' in kwargs:
kwargs.pop('dict_outputs')
if 'return_all_encoder_outputs' in kwargs:
kwargs.pop('return_all_encoder_outputs')
if 'intermediate_size' in kwargs:
inner_dim = kwargs.pop('intermediate_size')
if 'activation' in kwargs:
inner_activation = kwargs.pop('activation')
if 'dropout_rate' in kwargs:
output_dropout = kwargs.pop('dropout_rate')
if 'attention_dropout_rate' in kwargs:
attention_dropout = kwargs.pop('attention_dropout_rate')
super().__init__(**kwargs)
activation = tf.keras.activations.get(inner_activation)
initializer = tf.keras.initializers.get(initializer)
if embedding_width is None:
embedding_width = hidden_size
if embedding_layer is None:
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
self._embedding_dropout = tf.keras.layers.Dropout(
rate=output_dropout, name='embedding_dropout')
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
# The first 999 tokens are special tokens such as [PAD], [CLS], [SEP].
# We want to always drop [PAD], and never drop [CLS] and [SEP].
init_importance = tf.constant(token_loss_init_value, shape=(vocab_size,))
if token_allow_list:
init_importance = tf.tensor_scatter_nd_update(
tensor=init_importance,
indices=[[x] for x in token_allow_list],
updates=[1.0e4 for x in token_allow_list])
if token_deny_list:
init_importance = tf.tensor_scatter_nd_update(
tensor=init_importance,
indices=[[x] for x in token_deny_list],
updates=[-1.0e4 for x in token_deny_list])
self._token_importance_embed = layers.TokenImportanceWithMovingAvg(
vocab_size=vocab_size,
init_importance=init_importance,
moving_average_beta=token_loss_beta)
self._token_separator = layers.SelectTopK(top_k=token_keep_k)
self._transformer_layers = []
self._num_layers = num_layers
self._attention_mask_layer = layers.SelfAttentionMask(
name='self_attention_mask')
for i in range(num_layers):
layer = layers.TransformerEncoderBlock(
num_attention_heads=num_attention_heads,
inner_dim=inner_dim,
inner_activation=inner_activation,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
self._config = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'token_loss_init_value': token_loss_init_value,
'token_loss_beta': token_loss_beta,
'token_keep_k': token_keep_k,
'token_allow_list': token_allow_list,
'token_deny_list': token_deny_list,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'with_dense_inputs': with_dense_inputs,
}
if with_dense_inputs:
self.inputs = dict(
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
dense_inputs=tf.keras.Input(
shape=(None, embedding_width), dtype=tf.float32),
dense_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
dense_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
)
else:
self.inputs = dict(
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
def call(self, inputs):
if isinstance(inputs, dict):
word_ids = inputs.get('input_word_ids')
mask = inputs.get('input_mask')
type_ids = inputs.get('input_type_ids')
dense_inputs = inputs.get('dense_inputs', None)
dense_mask = inputs.get('dense_mask', None)
dense_type_ids = inputs.get('dense_type_ids', None)
else:
raise ValueError('Unexpected inputs type to %s.' % self.__class__)
word_embeddings = self._embedding_layer(word_ids)
if dense_inputs is not None:
# Concat the dense embeddings at sequence end.
word_embeddings = tf.concat([word_embeddings, dense_inputs], axis=1)
type_ids = tf.concat([type_ids, dense_type_ids], axis=1)
mask = tf.concat([mask, dense_mask], axis=1)
# absolute position embeddings.
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = self._type_embedding_layer(type_ids)
embeddings = word_embeddings + position_embeddings + type_embeddings
embeddings = self._embedding_norm_layer(embeddings)
embeddings = self._embedding_dropout(embeddings)
if self._embedding_projection is not None:
embeddings = self._embedding_projection(embeddings)
attention_mask = self._attention_mask_layer(embeddings, mask)
encoder_outputs = []
x = embeddings
# Get token routing.
token_importance = self._token_importance_embed(word_ids)
selected, not_selected = self._token_separator(token_importance)
# For a 12-layer BERT:
# 1. All tokens first go through 5 transformer layers, then
# 2. Only important tokens go through 1 transformer layer with cross
# attention to unimportant tokens, then
# 3. Only important tokens go through 5 transformer layers without cross
# attention.
# 4. Finally, all tokens go through the last layer.
# Step 1.
for layer in self._transformer_layers[:self._num_layers // 2 - 1]:
x = layer([x, attention_mask])
encoder_outputs.append(x)
# Step 2.
# First, separate important and non-important tokens.
x_selected = tf.gather(x, selected, batch_dims=1, axis=1)
mask_selected = tf.gather(mask, selected, batch_dims=1, axis=1)
attention_mask_token_drop = self._attention_mask_layer(
x_selected, mask_selected)
x_not_selected = tf.gather(x, not_selected, batch_dims=1, axis=1)
mask_not_selected = tf.gather(mask, not_selected, batch_dims=1, axis=1)
attention_mask_token_pass = self._attention_mask_layer(
x_selected, tf.concat([mask_selected, mask_not_selected], axis=1))
x_all = tf.concat([x_selected, x_not_selected], axis=1)
# Then, call transformer layer with cross attention.
x_selected = self._transformer_layers[self._num_layers // 2 - 1](
[x_selected, x_all, attention_mask_token_pass])
encoder_outputs.append(x_selected)
# Step 3.
for layer in self._transformer_layers[self._num_layers // 2:-1]:
x_selected = layer([x_selected, attention_mask_token_drop])
encoder_outputs.append(x_selected)
# Step 4.
# First, merge important and non-important tokens.
x_not_selected = tf.cast(x_not_selected, dtype=x_selected.dtype)
x = tf.concat([x_selected, x_not_selected], axis=1)
indices = tf.concat([selected, not_selected], axis=1)
reverse_indices = tf.argsort(indices)
x = tf.gather(x, reverse_indices, batch_dims=1, axis=1)
# Then, call transformer layer with all tokens.
x = self._transformer_layers[-1]([x, attention_mask])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
return dict(
sequence_output=encoder_outputs[-1],
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
def record_mlm_loss(self, mlm_ids: tf.Tensor, mlm_losses: tf.Tensor):
self._token_importance_embed.update_token_importance(
token_ids=mlm_ids, importance=mlm_losses)
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return dict(self._config)
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
if 'embedding_layer' in config and config['embedding_layer'] is not None:
warn_string = (
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you continue to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
print('WARNING: ' + warn_string)
logging.warn(warn_string)
return cls(**config)
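# A minimal usage sketch of the encoder above. This block is illustrative and
# not part of the original change; the tiny layer sizes, the 6-token batch and
# the loss values below are arbitrary assumptions.
if __name__ == '__main__':
  demo_encoder = TokenDropBertEncoder(
      vocab_size=100,
      hidden_size=16,
      num_layers=4,
      num_attention_heads=2,
      inner_dim=32,
      token_keep_k=3,
      token_allow_list=(),
      token_deny_list=(0,))
  demo_inputs = dict(
      input_word_ids=tf.constant([[5, 6, 7, 8, 0, 0]], dtype=tf.int32),
      input_mask=tf.constant([[1, 1, 1, 1, 0, 0]], dtype=tf.int32),
      input_type_ids=tf.zeros((1, 6), dtype=tf.int32))
  demo_outputs = demo_encoder(demo_inputs)
  # The sequence output keeps the full length: unimportant tokens are dropped
  # only inside the intermediate transformer layers.
  print(demo_outputs['sequence_output'].shape)  # (1, 6, 16)
  # During pretraining, per-token MLM losses are fed back to update the
  # moving-average token importance (see masked_lm.py).
  demo_encoder.record_mlm_loss(
      mlm_ids=tf.constant([[7, 8]]),
      mlm_losses=tf.constant([[2.0, 0.5]]))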
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Token dropping encoder configuration and instantiation."""
import dataclasses
from typing import Tuple
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.projects.token_dropping import encoder
@dataclasses.dataclass
class TokenDropBertEncoderConfig(encoders.BertEncoderConfig):
token_loss_init_value: float = 10.0
token_loss_beta: float = 0.995
token_keep_k: int = 256
token_allow_list: Tuple[int, ...] = (100, 101, 102, 103)
token_deny_list: Tuple[int, ...] = (0,)
@base_config.bind(TokenDropBertEncoderConfig)
def get_encoder(encoder_cfg: TokenDropBertEncoderConfig):
"""Instantiates 'TokenDropBertEncoder'.
Args:
encoder_cfg: A 'TokenDropBertEncoderConfig'.
Returns:
An 'encoder.TokenDropBertEncoder' object.
"""
return encoder.TokenDropBertEncoder(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_size,
return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True,
norm_first=encoder_cfg.norm_first,
token_loss_init_value=encoder_cfg.token_loss_init_value,
token_loss_beta=encoder_cfg.token_loss_beta,
token_keep_k=encoder_cfg.token_keep_k,
token_allow_list=encoder_cfg.token_allow_list,
token_deny_list=encoder_cfg.token_deny_list)
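# Illustrative sketch (not part of the original change): building a small
# encoder directly from the config dataclass above. The tiny sizes here are
# arbitrary assumptions.
if __name__ == '__main__':
  demo_cfg = TokenDropBertEncoderConfig(
      vocab_size=100,
      hidden_size=16,
      num_layers=4,
      num_attention_heads=2,
      intermediate_size=32,
      token_keep_k=3)
  demo_encoder = get_encoder(demo_cfg)
  print(type(demo_encoder).__name__)  # TokenDropBertEncoder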
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Token dropping BERT experiment configurations.
Only pretraining configs are provided here. A token dropping BERT checkpoint can
be loaded directly into a regular BERT, so you can simply use the regular BERT
for fine-tuning.
"""
# pylint: disable=g-doc-return-or-yield,line-too-long
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.projects.token_dropping import encoder_config
from official.projects.token_dropping import masked_lm
@exp_factory.register_config_factory('token_drop_bert/pretraining')
def token_drop_bert_pretraining() -> cfg.ExperimentConfig:
"""BERT pretraining with token dropping."""
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(enable_xla=True),
task=masked_lm.TokenDropMaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
any=encoder_config.TokenDropBertEncoderConfig(
vocab_size=30522, num_layers=1, token_keep_k=64),
type='any')),
train_data=pretrain_dataloader.BertPretrainDataConfig(),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
is_training=False)),
trainer=cfg.TrainerConfig(
train_steps=1000000,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate':
0.01,
'exclude_from_weight_decay':
['LayerNorm', 'layer_norm', 'bias'],
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 1e-4,
'end_learning_rate': 0.0,
}
},
'warmup': {
'type': 'polynomial'
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
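# Illustrative sketch (not part of the original change): the registered
# experiment can be looked up by name, which is how the training binary
# resolves --experiment=token_drop_bert/pretraining.
if __name__ == '__main__':
  demo_config = exp_factory.get_exp_config('token_drop_bert/pretraining')
  print(demo_config.task.model.encoder.type)  # 'any'
  print(demo_config.trainer.train_steps)  # 1000000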
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Masked language task."""
import dataclasses
from typing import Tuple
import tensorflow as tf
from official.core import task_factory
from official.nlp.tasks import masked_lm
@dataclasses.dataclass
class TokenDropMaskedLMConfig(masked_lm.MaskedLMConfig):
"""The model config."""
pass
@task_factory.register_task_cls(TokenDropMaskedLMConfig)
class TokenDropMaskedLMTask(masked_lm.MaskedLMTask):
"""Task object for Mask language modeling."""
def build_losses(self,
labels,
model_outputs,
metrics,
aux_losses=None) -> Tuple[tf.Tensor, tf.Tensor]:
"""Return the final loss, and the masked-lm loss."""
with tf.name_scope('MaskedLMTask/losses'):
metrics = dict([(metric.name, metric) for metric in metrics])
lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels['masked_lm_ids'],
tf.cast(model_outputs['mlm_logits'], tf.float32),
from_logits=True)
lm_label_weights = labels['masked_lm_weights']
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses *
lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in labels:
sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast(
model_outputs['next_sentence'], dtype=tf.float32)
sentence_loss = tf.reduce_mean(
tf.keras.losses.sparse_categorical_crossentropy(
sentence_labels, sentence_outputs, from_logits=True))
metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss
else:
total_loss = mlm_loss
if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss, lm_prediction_losses
def train_step(self, inputs, model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer, metrics):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
# Computes per-replica loss.
loss, lm_prediction_losses = self.build_losses(
labels=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
model.encoder_network.record_mlm_loss(
mlm_ids=inputs['masked_lm_ids'],
mlm_losses=lm_prediction_losses)
if self.task_config.scale_loss:
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
tvars = model.trainable_variables
if self.task_config.scale_loss:
grads = tape.gradient(scaled_loss, tvars)
else:
grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
def validation_step(self, inputs, model: tf.keras.Model, metrics):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
outputs = self.inference_step(inputs, model)
loss, _ = self.build_losses(
labels=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
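# Illustrative sketch (not part of the original change): the masked-LM loss
# computed in build_losses above is a weighted mean over masked positions, so
# zero-weight (padding) predictions do not contribute. Values are made up.
if __name__ == '__main__':
  demo_losses = tf.constant([[1.0, 3.0, 5.0]])
  demo_weights = tf.constant([[1.0, 1.0, 0.0]])
  demo_mlm_loss = tf.math.divide_no_nan(
      tf.reduce_sum(demo_losses * demo_weights), tf.reduce_sum(demo_weights))
  print(float(demo_mlm_loss))  # 2.0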
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.tasks.masked_lm."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.projects.token_dropping import encoder_config
from official.projects.token_dropping import masked_lm
class MLMTaskTest(tf.test.TestCase):
def test_task(self):
config = masked_lm.TokenDropMaskedLMConfig(
init_checkpoint=self.get_temp_dir(),
scale_loss=True,
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
any=encoder_config.TokenDropBertEncoderConfig(
vocab_size=30522, num_layers=1, token_keep_k=64),
type="any"),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
]),
train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path="dummy",
max_predictions_per_seq=20,
seq_length=128,
global_batch_size=1))
task = masked_lm.TokenDropMaskedLMTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
# Saves a checkpoint.
ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
ckpt.save(config.init_checkpoint)
task.initialize(model)
if __name__ == "__main__":
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A customized training binary for running token dropping experiments."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.projects.token_dropping import experiment_configs # pylint: disable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise, a continuous eval job
# may race against the train job when writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have a significant impact on model speed by utilizing float16 on GPUs
# and bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu,
**params.runtime.model_parallelism())
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
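# Example invocation (illustrative; the module path, model_dir and flag values
# below are assumptions, not part of the original change):
#
#   python3 -m official.projects.token_dropping.train \
#     --experiment=token_drop_bert/pretraining \
#     --config_file=path/to/pretrain.yaml \
#     --mode=train_and_eval \
#     --model_dir=/tmp/token_drop_bert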
task:
init_checkpoint: ''
model:
cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768, name: next_sentence, num_classes: 2}]
train_data:
drop_remainder: true
global_batch_size: 512
input_path: /path-to-data/wikipedia.tfrecord*,/path-to-data/books.tfrecord*
is_training: true
max_predictions_per_seq: 76
seq_length: 512
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: true
validation_data:
drop_remainder: false
global_batch_size: 512
input_path: /path-to-data/wikipedia.tfrecord*,/path-to-data/books.tfrecord*
is_training: false
max_predictions_per_seq: 76
seq_length: 512
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: true
trainer:
checkpoint_interval: 20000
max_to_keep: 5
optimizer_config:
learning_rate:
polynomial:
cycle: false
decay_steps: 1000000
end_learning_rate: 0.0
initial_learning_rate: 0.0001
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
warmup_steps: 10000
type: polynomial
steps_per_loop: 1000
summary_interval: 1000
train_steps: 1000000
validation_interval: 1000
validation_steps: 64
task:
init_checkpoint: ''
model:
cls_heads: []
train_data:
drop_remainder: true
global_batch_size: 512
input_path: /path-to-packed-data/wikipedia.tfrecord*,/path-to-packed-data/books.tfrecord*
is_training: true
max_predictions_per_seq: 76
seq_length: 512
use_next_sentence_label: false
use_position_id: false
use_v2_feature_names: true
validation_data:
drop_remainder: false
global_batch_size: 512
input_path: /path-to-packed-data/wikipedia.tfrecord*,/path-to-packed-data/books.tfrecord*
is_training: false
max_predictions_per_seq: 76
seq_length: 512
use_next_sentence_label: false
use_position_id: false
use_v2_feature_names: true
trainer:
checkpoint_interval: 20000
max_to_keep: 5
optimizer_config:
learning_rate:
polynomial:
cycle: false
decay_steps: 1000000
end_learning_rate: 0.0
initial_learning_rate: 0.0001
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
warmup_steps: 10000
type: polynomial
steps_per_loop: 1000
summary_interval: 1000
train_steps: 1000000
validation_interval: 1000
validation_steps: 64