Commit 59b5985e authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 453490507
parent ee708859
...@@ -34,6 +34,8 @@ from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM ...@@ -34,6 +34,8 @@ from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertTransformer from official.nlp.modeling.layers.mobile_bert_layers import MobileBertTransformer
from official.nlp.modeling.layers.multi_channel_attention import * from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.pack_optimization import PackBertEmbeddings
from official.nlp.modeling.layers.pack_optimization import StridedTransformerEncoderBlock
from official.nlp.modeling.layers.position_embedding import PositionEmbedding from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.position_embedding import RelativePositionBias from official.nlp.modeling.layers.position_embedding import RelativePositionBias
from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pack sequence optimization on accelerators."""
from typing import Dict
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import self_attention_mask
from official.nlp.modeling.layers import transformer_encoder_block
def _packing_mask(segment_id, source_segment_id, dtype=tf.float32):
"""Calculates a segment mask for attention.
Args:
segment_id: [B, T]
source_segment_id: [B, S]
dtype: data type of generated mask.
Returns:
segment_mask: [B, T, S]
"""
if segment_id is None or source_segment_id is None:
return None
# Compute [B, T, S] = [B, T, 1] == [B, 1, S]
return tf.cast(
tf.equal(
tf.expand_dims(segment_id, 2), tf.expand_dims(source_segment_id, 1)),
dtype=dtype)
@tf.keras.utils.register_keras_serializable(package='Text')
class PackBertEmbeddings(tf.keras.layers.Layer):
"""Performs packing tricks for BERT inputs to improve TPU utilization."""
def __init__(self, pack_sequences: int, **kwargs):
super().__init__(**kwargs)
self.pack_sequences = pack_sequences
def call(self, input_embeddings: tf.Tensor,
input_mask: tf.Tensor) -> Dict[str, tf.Tensor]:
batch_size, seq_len, embedding_dim = tf_utils.get_shape_list(
input_embeddings, expected_rank=3)
example_ids = None
reduced_batch_size = batch_size // self.pack_sequences
packed_seq_len = self.pack_sequences * seq_len
packed_embeddings = tf.reshape(
input_embeddings, [reduced_batch_size, packed_seq_len, embedding_dim])
input_mask = tf.reshape(input_mask, [reduced_batch_size, packed_seq_len])
example_ids = 1 + tf.range(self.pack_sequences)
# Shape: [batch_size, seq_len, pack_sequences].
example_ids = tf.tile(example_ids[None, :, None],
[reduced_batch_size, 1, seq_len])
example_ids = tf.reshape(example_ids, [reduced_batch_size, packed_seq_len])
example_ids = tf.where(
tf.math.equal(input_mask, 0), tf.zeros_like(example_ids), example_ids)
packing_mask = _packing_mask(example_ids, example_ids, dtype=tf.bool)
attention_mask = self_attention_mask.get_mask(
packed_embeddings, input_mask, dtype=tf.bool)
combined_attention_mask = tf.cast(
tf.math.logical_and(attention_mask, packing_mask), tf.float32)
return dict(
packed_embeddings=packed_embeddings,
combined_attention_mask=combined_attention_mask)
@tf.keras.utils.register_keras_serializable(package='Text')
class StridedTransformerEncoderBlock(
transformer_encoder_block.TransformerEncoderBlock):
"""Transformer layer for packing optimization to stride over inputs."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self._output_range is not None:
raise ValueError('StridedTransformerEncoderBlock does not '
'support `output_range` argument.')
def call(self, inputs, stride: tf.Tensor):
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError('Unexpected inputs to %s with length at %d' %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if self._norm_first:
source_tensor = input_tensor[:, ::stride, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm_kv(key_value)
target_tensor = input_tensor[:, ::stride, :]
if attention_mask is not None:
attention_mask = attention_mask[:, ::stride, :]
if key_value is None:
key_value = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
# Important to not combine `self._norm_first` and
# `self._use_query_residual` into one if clause because else is only for
# `_norm_first == False`.
if self._use_query_residual:
attention_output = source_tensor + attention_output
else:
if self._use_query_residual:
attention_output = target_tensor + attention_output
attention_output = self._attention_layer_norm(attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
inner_output = self._intermediate_dense(attention_output)
inner_output = self._intermediate_activation_layer(inner_output)
inner_output = self._inner_dropout_layer(inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
if self._norm_first:
return source_attention_output + layer_output
layer_output = tf.cast(layer_output, tf.float32)
return self._output_layer_norm(layer_output + attention_output)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pack_optimization."""
import tensorflow as tf
from official.nlp.modeling.layers import pack_optimization
class PackOptimizationTest(tf.test.TestCase):
def test_bert_embedding_packing(self):
batch_size, seq_len, embed_dim = 2, 4, 8
pack_sequences = 2
token_and_position_embed = tf.ones((batch_size, seq_len, embed_dim),
dtype=tf.float32)
input_mask = tf.ones((batch_size, seq_len), dtype=tf.int32)
layer = pack_optimization.PackBertEmbeddings(pack_sequences=pack_sequences)
outputs = layer(token_and_position_embed, input_mask)
self.assertEqual(outputs["packed_embeddings"].shape, (1, 8, embed_dim))
self.assertEqual(outputs["combined_attention_mask"].shape, (1, 8, 8))
def test_strided_transformer_encoder_block(self):
inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
transformer = pack_optimization.StridedTransformerEncoderBlock(
num_attention_heads=2, inner_dim=4, inner_activation="relu")
_ = transformer([inputs, attention_mask],
stride=tf.constant(2, dtype=tf.int32))
if __name__ == "__main__":
tf.test.main()
...@@ -13,10 +13,38 @@ ...@@ -13,10 +13,38 @@
# limitations under the License. # limitations under the License.
"""Keras layer that creates a self-attention mask.""" """Keras layer that creates a self-attention mask."""
from typing import Optional
import tensorflow as tf import tensorflow as tf
def get_mask(inputs: tf.Tensor,
to_mask: tf.Tensor,
dtype: Optional[tf.DType] = None) -> tf.Tensor:
"""Gets a 3D self-attention mask.
Args:
inputs: from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length,
...].
to_mask: int32 Tensor of shape [batch_size, to_seq_length].
dtype: the output Tensor dtype.
Returns:
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
"""
from_shape = tf.shape(inputs)
batch_size = from_shape[0]
from_seq_length = from_shape[1]
dtype = inputs.dtype if dtype is None else dtype
to_shape = tf.shape(to_mask)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]), dtype=dtype)
return tf.broadcast_to(to_mask, [batch_size, from_seq_length, to_seq_length])
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
class SelfAttentionMask(tf.keras.layers.Layer): class SelfAttentionMask(tf.keras.layers.Layer):
"""Create 3D attention mask from a 2D tensor mask. """Create 3D attention mask from a 2D tensor mask.
...@@ -33,16 +61,4 @@ class SelfAttentionMask(tf.keras.layers.Layer): ...@@ -33,16 +61,4 @@ class SelfAttentionMask(tf.keras.layers.Layer):
if isinstance(inputs, list) and to_mask is None: if isinstance(inputs, list) and to_mask is None:
to_mask = inputs[1] to_mask = inputs[1]
inputs = inputs[0] inputs = inputs[0]
from_shape = tf.shape(inputs) return get_mask(inputs, to_mask)
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_shape = tf.shape(to_mask)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
dtype=inputs.dtype)
return tf.broadcast_to(to_mask,
[batch_size, from_seq_length, to_seq_length])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment