"src/graph/sampling/vscode:/vscode.git/clone" did not exist on "e4cc81852def1e0ef52792eef9e0c48424ddc8f7"
Commit 23804bc5 authored by xinliupitt's avatar xinliupitt
Browse files

transformer, attention layers

parent bda18166
...@@ -523,9 +523,8 @@ class CachedAttention(MultiHeadAttention): ...@@ -523,9 +523,8 @@ class CachedAttention(MultiHeadAttention):
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
query = tf.multiply(query,1.0 / math.sqrt(float(self._key_size)))
attention_scores = tf.einsum(self._dot_product_equation, key, query) attention_scores = tf.einsum(self._dot_product_equation, key, query)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, F, T] # `attention_scores` = [B, N, F, T]
......
...@@ -65,6 +65,9 @@ class Transformer(tf.keras.layers.Layer): ...@@ -65,6 +65,9 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer=None, activity_regularizer=None,
kernel_constraint=None, kernel_constraint=None,
bias_constraint=None, bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
**kwargs): **kwargs):
super(Transformer, self).__init__(**kwargs) super(Transformer, self).__init__(**kwargs)
...@@ -81,6 +84,9 @@ class Transformer(tf.keras.layers.Layer): ...@@ -81,6 +84,9 @@ class Transformer(tf.keras.layers.Layer):
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer) self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint) self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint) self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
def build(self, input_shape): def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
...@@ -117,6 +123,7 @@ class Transformer(tf.keras.layers.Layer): ...@@ -117,6 +123,7 @@ class Transformer(tf.keras.layers.Layer):
num_heads=self._num_heads, num_heads=self._num_heads,
key_size=self._attention_head_size, key_size=self._attention_head_size,
dropout=self._attention_dropout_rate, dropout=self._attention_dropout_rate,
use_bias=self._use_bias,
name="self_attention", name="self_attention",
**common_kwargs) **common_kwargs)
# pylint: disable=protected-access # pylint: disable=protected-access
...@@ -132,7 +139,7 @@ class Transformer(tf.keras.layers.Layer): ...@@ -132,7 +139,7 @@ class Transformer(tf.keras.layers.Layer):
tf.keras.layers.LayerNormalization( tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", name="self_attention_layer_norm",
axis=-1, axis=-1,
epsilon=1e-12, epsilon=self._norm_epsilon,
dtype=tf.float32)) dtype=tf.float32))
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense( self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd", "abc,cd->abd",
...@@ -157,7 +164,8 @@ class Transformer(tf.keras.layers.Layer): ...@@ -157,7 +164,8 @@ class Transformer(tf.keras.layers.Layer):
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization( self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32) name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon,
dtype=tf.float32)
super(Transformer, self).build(input_shape) super(Transformer, self).build(input_shape)
...@@ -203,13 +211,22 @@ class Transformer(tf.keras.layers.Layer): ...@@ -203,13 +211,22 @@ class Transformer(tf.keras.layers.Layer):
target_tensor = input_tensor[:, 0:self._output_range, :] target_tensor = input_tensor[:, 0:self._output_range, :]
attention_mask = attention_mask[:, 0:self._output_range, :] attention_mask = attention_mask[:, 0:self._output_range, :]
else: else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
target_tensor = input_tensor target_tensor = input_tensor
attention_output = self._attention_layer( attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask) query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(target_tensor + if self._norm_first:
attention_output) attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
intermediate_output = self._intermediate_dense(attention_output) intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer( intermediate_output = self._intermediate_activation_layer(
intermediate_output) intermediate_output)
...@@ -219,7 +236,10 @@ class Transformer(tf.keras.layers.Layer): ...@@ -219,7 +236,10 @@ class Transformer(tf.keras.layers.Layer):
# is always fp32 for now. Cast layer_output to fp32 for the subsequent # is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add. # add.
layer_output = tf.cast(layer_output, tf.float32) layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm(layer_output + attention_output) if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output return layer_output
...@@ -273,6 +293,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -273,6 +293,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
activity_regularizer=None, activity_regularizer=None,
kernel_constraint=None, kernel_constraint=None,
bias_constraint=None, bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
**kwargs): **kwargs):
super(TransformerDecoderLayer, self).__init__(**kwargs) super(TransformerDecoderLayer, self).__init__(**kwargs)
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
...@@ -289,6 +312,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -289,6 +312,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer) self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint) self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint) self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
if self.multi_channel_cross_attention: if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else: else:
...@@ -318,6 +344,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -318,6 +344,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
key_size=self.attention_head_size, key_size=self.attention_head_size,
dropout=self.attention_dropout_rate, dropout=self.attention_dropout_rate,
use_bias=self._use_bias,
name="self_attention", name="self_attention",
**common_kwargs) **common_kwargs)
self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense( self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
...@@ -330,13 +357,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -330,13 +357,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
rate=self.dropout_rate) rate=self.dropout_rate)
self.self_attention_layer_norm = ( self.self_attention_layer_norm = (
tf.keras.layers.LayerNormalization( tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", axis=-1, epsilon=1e-12)) name="self_attention_layer_norm",
axis=-1, epsilon=self._norm_epsilon))
# Encoder-decoder attention. # Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls( self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
key_size=self.attention_head_size, key_size=self.attention_head_size,
dropout=self.attention_dropout_rate, dropout=self.attention_dropout_rate,
output_shape=hidden_size, output_shape=hidden_size,
use_bias=self._use_bias,
name="attention/encdec", name="attention/encdec",
**common_kwargs) **common_kwargs)
...@@ -344,7 +373,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -344,7 +373,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
rate=self.dropout_rate) rate=self.dropout_rate)
self.encdec_attention_layer_norm = ( self.encdec_attention_layer_norm = (
tf.keras.layers.LayerNormalization( tf.keras.layers.LayerNormalization(
name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12)) name="attention/encdec_output_layer_norm",
axis=-1, epsilon=self._norm_epsilon))
# Feed-forward projection. # Feed-forward projection.
self.intermediate_dense = tf.keras.layers.experimental.EinsumDense( self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
...@@ -363,7 +393,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -363,7 +393,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
**common_kwargs) **common_kwargs)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate) self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.output_layer_norm = tf.keras.layers.LayerNormalization( self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12) name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
super(TransformerDecoderLayer, self).build(input_shape) super(TransformerDecoderLayer, self).build(input_shape)
def common_layers_with_encoder(self): def common_layers_with_encoder(self):
...@@ -384,6 +414,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -384,6 +414,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
"TransformerDecoderLayer must have 4 inputs, but it got: %d" % "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
len(inputs)) len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4] input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
source_tensor = input_tensor
if self._norm_first:
input_tensor = self.self_attention_layer_norm(input_tensor)
self_attention_output, cache = self.self_attention( self_attention_output, cache = self.self_attention(
query=input_tensor, query=input_tensor,
value=input_tensor, value=input_tensor,
...@@ -391,8 +424,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -391,8 +424,15 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
cache=cache, cache=cache,
decode_loop_step=decode_loop_step) decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output) self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm( if self._norm_first:
input_tensor + self_attention_output) self_attention_output = source_tensor + self_attention_output
else:
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
if self._norm_first:
source_self_attention_output = self_attention_output
self_attention_output = self.encdec_attention_layer_norm(
self_attention_output)
cross_attn_inputs = dict( cross_attn_inputs = dict(
query=self_attention_output, query=self_attention_output,
value=memory, value=memory,
...@@ -402,13 +442,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -402,13 +442,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
cross_attn_inputs["context_attention_weights"] = inputs[-1] cross_attn_inputs["context_attention_weights"] = inputs[-1]
attention_output = self.encdec_attention(**cross_attn_inputs) attention_output = self.encdec_attention(**cross_attn_inputs)
attention_output = self.encdec_attention_dropout(attention_output) attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output + if self._norm_first:
attention_output) attention_output = source_self_attention_output + attention_output
else:
attention_output = self.encdec_attention_layer_norm(
self_attention_output +
attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self.output_layer_norm(attention_output)
intermediate_output = self.intermediate_dense(attention_output) intermediate_output = self.intermediate_dense(attention_output)
intermediate_output = self.intermediate_activation_layer( intermediate_output = self.intermediate_activation_layer(
intermediate_output) intermediate_output)
layer_output = self.output_dense(intermediate_output) layer_output = self.output_dense(intermediate_output)
layer_output = self.output_dropout(layer_output) layer_output = self.output_dropout(layer_output)
layer_output = self.output_layer_norm(layer_output + attention_output) if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self.output_layer_norm(layer_output + attention_output)
return layer_output, cache return layer_output, cache
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment