Commit 23804bc5 authored by xinliupitt

transformer, attention layers: add use_bias, norm_first, and norm_epsilon options to Transformer and TransformerDecoderLayer; scale the query instead of the attention scores in CachedAttention

parent bda18166
@@ -523,9 +523,8 @@ class CachedAttention(MultiHeadAttention):
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
+    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
     attention_scores = tf.einsum(self._dot_product_equation, key, query)
-    attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._key_size)))
     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, F, T]
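Note on the hunk above: because the einsum contraction is linear in `query`, scaling the query by 1/sqrt(key_size) before the dot product gives the same attention scores as scaling the scores afterwards. A minimal standalone check (the equation string and the [batch, seq, heads, size] layout below are stand-ins for this layer's generated `_dot_product_equation`, not taken from the repo):

import math
import tensorflow as tf

key_size = 64
query = tf.random.normal([2, 5, 4, key_size])  # [batch, from_seq, heads, key_size]
key = tf.random.normal([2, 7, 4, key_size])    # [batch, to_seq, heads, key_size]
scale = 1.0 / math.sqrt(float(key_size))
equation = "bknh,bqnh->bnqk"  # stand-in dot-product equation: (key, query) -> [B, N, F, T]

scores_new = tf.einsum(equation, key, tf.multiply(query, scale))  # scale query first
scores_old = tf.multiply(tf.einsum(equation, key, query), scale)  # scale scores after
tf.debugging.assert_near(scores_new, scores_old, atol=1e-5)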
@@ -65,6 +65,9 @@ class Transformer(tf.keras.layers.Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               use_bias=True,
+               norm_first=False,
+               norm_epsilon=1e-12,
                **kwargs):
     super(Transformer, self).__init__(**kwargs)
@@ -81,6 +84,9 @@ class Transformer(tf.keras.layers.Layer):
     self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+    self._use_bias = use_bias
+    self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon
 
   def build(self, input_shape):
     input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
@@ -117,6 +123,7 @@ class Transformer(tf.keras.layers.Layer):
         num_heads=self._num_heads,
         key_size=self._attention_head_size,
         dropout=self._attention_dropout_rate,
+        use_bias=self._use_bias,
         name="self_attention",
         **common_kwargs)
     # pylint: disable=protected-access
@@ -132,7 +139,7 @@ class Transformer(tf.keras.layers.Layer):
         tf.keras.layers.LayerNormalization(
             name="self_attention_layer_norm",
             axis=-1,
-            epsilon=1e-12,
+            epsilon=self._norm_epsilon,
             dtype=tf.float32))
     self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
@@ -157,7 +164,8 @@ class Transformer(tf.keras.layers.Layer):
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
+        name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon,
+        dtype=tf.float32)
     super(Transformer, self).build(input_shape)
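The layer-norm epsilon threaded through above matters most when the variance along the normalized axis is tiny; a quick standalone illustration with plain Keras (not this repo's tests):

import tensorflow as tf

# Near-constant features: variance along the last axis is ~2.2e-7.
x = tf.constant([[1.0, 1.0 + 1e-3, 1.0]])
for eps in (1e-12, 1e-3):
  ln = tf.keras.layers.LayerNormalization(axis=-1, epsilon=eps)
  print(eps, ln(x).numpy())
# epsilon=1e-12 leaves the ~[-0.7, 1.4, -0.7] normalized pattern intact, while
# epsilon=1e-3 dominates the tiny variance and shrinks the output toward zero,
# which is why the default stays at 1e-12 but is now configurable per model.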
@@ -203,13 +211,22 @@ class Transformer(tf.keras.layers.Layer):
       target_tensor = input_tensor[:, 0:self._output_range, :]
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
+      if self._norm_first:
+        source_tensor = input_tensor
+        input_tensor = self._attention_layer_norm(input_tensor)
       target_tensor = input_tensor
     attention_output = self._attention_layer(
         query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
-    attention_output = self._attention_layer_norm(target_tensor +
-                                                  attention_output)
+    if self._norm_first:
+      attention_output = source_tensor + attention_output
+    else:
+      attention_output = self._attention_layer_norm(target_tensor +
+                                                    attention_output)
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self._output_layer_norm(attention_output)
     intermediate_output = self._intermediate_dense(attention_output)
     intermediate_output = self._intermediate_activation_layer(
         intermediate_output)
@@ -219,7 +236,10 @@ class Transformer(tf.keras.layers.Layer):
     # is always fp32 for now. Cast layer_output to fp32 for the subsequent
     # add.
     layer_output = tf.cast(layer_output, tf.float32)
-    layer_output = self._output_layer_norm(layer_output + attention_output)
+    if self._norm_first:
+      layer_output = source_attention_output + layer_output
+    else:
+      layer_output = self._output_layer_norm(layer_output + attention_output)
     return layer_output
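The branches added above implement the pre-layer-norm ("norm_first") residual ordering. Condensed to a generic sublayer, the two orderings the flag selects between look like this (a sketch, not code from this file):

import tensorflow as tf

def residual_block(x, sublayer, layer_norm, norm_first):
  """Applies one residual sublayer in either post-LN or pre-LN order."""
  if norm_first:
    # Pre-LN: normalize the input, run the sublayer, then add the raw input
    # back; the next block (or a final norm) normalizes the running stream.
    return x + sublayer(layer_norm(x))
  # Post-LN (the previous default): run the sublayer on the raw input and
  # normalize the residual sum.
  return layer_norm(x + sublayer(x))

# e.g. with a feed-forward sublayer:
ffn = tf.keras.layers.Dense(512, activation="relu")
ln = tf.keras.layers.LayerNormalization(axis=-1)
y = residual_block(tf.random.normal([2, 16, 512]), ffn, ln, norm_first=True)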
@@ -273,6 +293,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               use_bias=True,
+               norm_first=False,
+               norm_epsilon=1e-12,
                **kwargs):
     super(TransformerDecoderLayer, self).__init__(**kwargs)
     self.num_attention_heads = num_attention_heads
@@ -289,6 +312,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
+    self._use_bias = use_bias
+    self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
@@ -318,6 +344,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
+        use_bias=self._use_bias,
         name="self_attention",
         **common_kwargs)
     self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
@@ -330,13 +357,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         rate=self.dropout_rate)
     self.self_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
+            name="self_attention_layer_norm",
+            axis=-1, epsilon=self._norm_epsilon))
     # Encoder-decoder attention.
     self.encdec_attention = self._cross_attention_cls(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
+        use_bias=self._use_bias,
         name="attention/encdec",
         **common_kwargs)
@@ -344,7 +373,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         rate=self.dropout_rate)
     self.encdec_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
+            name="attention/encdec_output_layer_norm",
+            axis=-1, epsilon=self._norm_epsilon))
     # Feed-forward projection.
     self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
@@ -363,7 +393,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12)
+        name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
     super(TransformerDecoderLayer, self).build(input_shape)
 
   def common_layers_with_encoder(self):
@@ -384,6 +414,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
           len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
+    source_tensor = input_tensor
+    if self._norm_first:
+      input_tensor = self.self_attention_layer_norm(input_tensor)
     self_attention_output, cache = self.self_attention(
         query=input_tensor,
         value=input_tensor,
@@ -391,8 +424,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         cache=cache,
         decode_loop_step=decode_loop_step)
     self_attention_output = self.self_attention_dropout(self_attention_output)
-    self_attention_output = self.self_attention_layer_norm(
-        input_tensor + self_attention_output)
+    if self._norm_first:
+      self_attention_output = source_tensor + self_attention_output
+    else:
+      self_attention_output = self.self_attention_layer_norm(
+          input_tensor + self_attention_output)
+    if self._norm_first:
+      source_self_attention_output = self_attention_output
+      self_attention_output = self.encdec_attention_layer_norm(
+          self_attention_output)
     cross_attn_inputs = dict(
         query=self_attention_output,
         value=memory,
@@ -402,13 +442,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
       cross_attn_inputs["context_attention_weights"] = inputs[-1]
     attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
-    attention_output = self.encdec_attention_layer_norm(self_attention_output +
-                                                        attention_output)
+    if self._norm_first:
+      attention_output = source_self_attention_output + attention_output
+    else:
+      attention_output = self.encdec_attention_layer_norm(
+          self_attention_output +
+          attention_output)
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self.output_layer_norm(attention_output)
     intermediate_output = self.intermediate_dense(attention_output)
     intermediate_output = self.intermediate_activation_layer(
         intermediate_output)
     layer_output = self.output_dense(intermediate_output)
     layer_output = self.output_dropout(layer_output)
-    layer_output = self.output_layer_norm(layer_output + attention_output)
+    if self._norm_first:
+      layer_output = source_attention_output + layer_output
+    else:
+      layer_output = self.output_layer_norm(layer_output + attention_output)
     return layer_output, cache
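For reference, a hedged usage sketch of the new options. The import path and the pre-existing constructor arguments (num_attention_heads, intermediate_size, intermediate_activation) are assumed from the rest of this file and may differ in your checkout:

import tensorflow as tf
from official.nlp.modeling.layers import transformer  # assumed module path

# Pre-layer-norm encoder block without projection biases and with a looser
# layer-norm epsilon; all three keyword arguments are the ones this commit adds.
encoder_block = transformer.Transformer(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation="relu",
    use_bias=False,
    norm_first=True,
    norm_epsilon=1e-6)

data = tf.random.normal([2, 16, 512])
mask = tf.ones([2, 16, 16])
encoded = encoder_block([data, mask])  # [2, 16, 512]

# The decoder layer accepts the same three options; per this diff its call
# takes [inputs, memory, cross-attention mask, self-attention mask] and
# returns (output, cache).
decoder_block = transformer.TransformerDecoderLayer(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation="relu",
    use_bias=False,
    norm_first=True,
    norm_epsilon=1e-6)

self_mask = tf.ones([2, 16, 16])
decoded, _ = decoder_block([data, encoded, mask, self_mask])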