Commit f8946f24 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 464649779
parent 0ff8db0a
@@ -22,26 +22,6 @@ from official.nlp.modeling.layers import transformer_encoder_block
from official.nlp.modeling.layers import transformer_scaffold
-def _packing_mask(segment_id, source_segment_id, dtype=tf.float32):
-  """Calculates a segment mask for attention.
-  Args:
-    segment_id: [B, T]
-    source_segment_id: [B, S]
-    dtype: data type of generated mask.
-  Returns:
-    segment_mask: [B, T, S]
-  """
-  if segment_id is None or source_segment_id is None:
-    return None
-  # Compute [B, T, S] = [B, T, 1] == [B, 1, S]
-  return tf.cast(
-      tf.equal(
-          tf.expand_dims(segment_id, 2), tf.expand_dims(source_segment_id, 1)),
-      dtype=dtype)
@tf.keras.utils.register_keras_serializable(package='Text')
class PackBertEmbeddings(tf.keras.layers.Layer):
"""Performs packing tricks for BERT inputs to improve TPU utilization."""
@@ -54,7 +34,6 @@ class PackBertEmbeddings(tf.keras.layers.Layer):
input_mask: tf.Tensor) -> Dict[str, tf.Tensor]:
batch_size, seq_len, embedding_dim = tf_utils.get_shape_list(
input_embeddings, expected_rank=3)
-    example_ids = None
reduced_batch_size = batch_size // self.pack_sequences
packed_seq_len = self.pack_sequences * seq_len
packed_embeddings = tf.reshape(
@@ -67,7 +46,10 @@ class PackBertEmbeddings(tf.keras.layers.Layer):
example_ids = tf.reshape(example_ids, [reduced_batch_size, packed_seq_len])
example_ids = tf.where(
tf.math.equal(input_mask, 0), tf.zeros_like(example_ids), example_ids)
-    packing_mask = _packing_mask(example_ids, example_ids, dtype=tf.bool)
+    packing_mask = tf.cast(
+        tf.equal(
+            tf.expand_dims(example_ids, 2), tf.expand_dims(example_ids, 1)),
+        dtype=tf.bool)
attention_mask = self_attention_mask.get_mask(
packed_embeddings, input_mask, dtype=tf.bool)
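
The surrounding `call` packs `pack_sequences` examples per row: a [B, T, E] batch is reshaped to [B // pack_sequences, pack_sequences * T, E], per-token example ids are zeroed at padding positions, and the inlined broadcast above keeps attention within each original example. A self-contained sketch with made-up shapes and a hypothetical way of assigning example ids (the layer derives them differently):

import tensorflow as tf

pack_sequences = 2
batch_size, seq_len, embedding_dim = 4, 3, 8
input_embeddings = tf.random.uniform([batch_size, seq_len, embedding_dim])
input_mask = tf.constant([[1, 1, 1], [1, 1, 0], [1, 0, 0], [1, 1, 1]])

reduced_batch_size = batch_size // pack_sequences  # 2 packed rows
packed_seq_len = pack_sequences * seq_len          # 6 tokens per packed row

packed_embeddings = tf.reshape(
    input_embeddings, [reduced_batch_size, packed_seq_len, embedding_dim])
packed_input_mask = tf.reshape(input_mask, [reduced_batch_size, packed_seq_len])

# Hypothetical example ids: one distinct non-zero id per original example,
# then zeroed wherever the packed input mask marks padding.
example_ids = tf.reshape(
    tf.repeat(tf.range(1, batch_size + 1), seq_len),
    [reduced_batch_size, packed_seq_len])
example_ids = tf.where(
    tf.math.equal(packed_input_mask, 0), tf.zeros_like(example_ids), example_ids)

# Same broadcast as the inlined code above: a boolean packing mask.
packing_mask = tf.equal(
    tf.expand_dims(example_ids, 2), tf.expand_dims(example_ids, 1))
print(packed_embeddings.shape, packing_mask.shape)  # (2, 6, 8) (2, 6, 6)
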
@@ -225,21 +207,22 @@ class StridedTransformerScaffold(transformer_scaffold.TransformerScaffold):
target_tensor = input_tensor[:, ::stride, :]
attention_output = self._attention_layer(
-        query=target_tensor, value=key_value, attention_mask=attention_mask,
+        query=target_tensor,
+        value=key_value,
+        attention_mask=attention_mask,
training=training)
-    attention_output = self._attention_dropout(attention_output,
-                                               training=training)
+    attention_output = self._attention_dropout(
+        attention_output, training=training)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
-      attention_output = self._attention_layer_norm(target_tensor +
-                                                     attention_output,
-                                                     training=training)
+      attention_output = self._attention_layer_norm(
+          target_tensor + attention_output, training=training)
if self._norm_first:
source_attention_output = attention_output
-      attention_output = self._output_layer_norm(attention_output,
-                                                 training=training)
+      attention_output = self._output_layer_norm(
+          attention_output, training=training)
if self._feedforward_block is None:
intermediate_output = self._intermediate_dense(attention_output)
@@ -251,17 +234,17 @@ class StridedTransformerScaffold(transformer_scaffold.TransformerScaffold):
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
-        layer_output = self._output_layer_norm(layer_output + attention_output,
-                                               training=training)
+        layer_output = self._output_layer_norm(
+            layer_output + attention_output, training=training)
else:
if self._norm_first:
# if norm_first, assume the feedforward block will not apply layer norm
-        layer_output = self._feedforward_block(attention_output,
-                                               training=training)
+        layer_output = self._feedforward_block(
+            attention_output, training=training)
layer_output += source_attention_output
else:
        # if not norm_first, assume that the feedforward does apply layer norm
-        layer_output = self._feedforward_block(attention_output,
-                                               training=training)
+        layer_output = self._feedforward_block(
+            attention_output, training=training)
return layer_output
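
For reference, the reformatted block above implements strided cross-attention: queries come from every `stride`-th position of the input while keys and values come from the unstrided `key_value` tensor, and the result is wired through the usual pre-norm (`norm_first`) or post-norm residual path. A minimal standalone sketch of the post-norm branch using a stock Keras attention layer (the `stride` value and hyperparameters are illustrative, not taken from the library):

import tensorflow as tf

stride = 2
batch, seq_len, hidden = 2, 8, 16
input_tensor = tf.random.uniform([batch, seq_len, hidden])

attention_layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
attention_layer_norm = tf.keras.layers.LayerNormalization()

key_value = input_tensor
# Queries come only from every `stride`-th position, so the block's output is
# shorter than its input: [batch, seq_len // stride, hidden].
target_tensor = input_tensor[:, ::stride, :]

attention_output = attention_layer(query=target_tensor, value=key_value)
# Post-norm residual (the `not norm_first` branch): add the strided input back
# before layer normalization.
attention_output = attention_layer_norm(target_tensor + attention_output)
print(attention_output.shape)  # (2, 4, 16)
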