Commit d3d4177d authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 395996268
parent a173db62
@@ -116,6 +116,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     self._attention_initializer = self._kernel_initializer
     self._attention_axes = attention_axes
 
+  def _maybe_build(self, inputs):
+    super()._maybe_build(inputs[:1])
+
   def build(self, input_shape):
     if isinstance(input_shape, tf.TensorShape):
       input_tensor_shape = input_shape
@@ -247,6 +250,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         [`query tensor`, `key value tensor`, `attention mask`] to have separate
           input streams for the query, and key/value to the multi-head
           attention.
+        [`query tensor`, `key value tensor`, `attention mask`, `pos_embed`] to
+          have an additional pos_embed that is added to the query and key of
+          every self-attention layer.
 
     Returns:
       An output tensor with the same dimensions as input/query tensor.
@@ -255,13 +261,18 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       if len(inputs) == 2:
         input_tensor, attention_mask = inputs
         key_value = None
+        pos_embed = None
       elif len(inputs) == 3:
         input_tensor, key_value, attention_mask = inputs
+        pos_embed = None
+      elif len(inputs) == 4:
+        input_tensor, key_value, attention_mask, pos_embed = inputs
       else:
         raise ValueError("Unexpected inputs to %s with length at %d" %
                          (self.__class__, len(inputs)))
     else:
-      input_tensor, key_value, attention_mask = (inputs, None, None)
+      input_tensor, key_value, attention_mask, pos_embed = (inputs, None, None,
+                                                            None)
 
     if self._output_range:
       if self._norm_first:
@@ -282,8 +293,14 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     if key_value is None:
       key_value = input_tensor
+    if pos_embed is None:
+      query = target_tensor
+      key = key_value
+    else:
+      query = target_tensor + pos_embed
+      key = key_value + pos_embed
     attention_output = self._attention_layer(
-        query=target_tensor, value=key_value, attention_mask=attention_mask)
+        query=query, key=key, value=key_value, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     if self._norm_first:
       attention_output = source_tensor + attention_output
......
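A minimal usage sketch of the extended input convention for TransformerEncoderBlock. Here `encoder_block` is assumed to be an instance of the patched class and all shapes are placeholders; with the old 2- or 3-element inputs (or `pos_embed=None`) the behavior is unchanged.

    import tensorflow as tf

    batch, seq_len, hidden = 2, 8, 16                        # placeholder sizes
    x = tf.random.normal([batch, seq_len, hidden])           # query tensor
    mask = tf.ones([batch, seq_len, seq_len])                # attention mask
    pos_embed = tf.random.normal([batch, seq_len, hidden])   # added to query and key

    # New 4-element form: [query tensor, key/value tensor, attention mask, pos_embed].
    output = encoder_block([x, x, mask, pos_embed])          # same shape as x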
@@ -232,6 +232,9 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
     else:
       self._cross_attention_cls = attention.MultiHeadAttention
 
+  def _maybe_build(self, inputs):
+    super()._maybe_build(inputs[:1])
+
   def build(self, input_shape):
     target_tensor_shape = tf.TensorShape(input_shape[0])
     if len(target_tensor_shape.as_list()) != 3:
@@ -370,22 +373,57 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         self.intermediate_dense, self.output_dense, self.output_layer_norm
     ]
 
-  def call(self, inputs, cache=None, decode_loop_step=None):
-    if self.multi_channel_cross_attention:
-      if len(inputs) != 5:
-        raise ValueError(
-            "TransformerDecoderBlock must have 5 inputs, when it uses "
-            "multi_channel_cross_attention. But it got: %d" % len(inputs))
-    elif len(inputs) != 4:
-      raise ValueError(
-          "TransformerDecoderBlock must have 4 inputs, but it got: %d" %
-          len(inputs))
-    input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
+  def _parse_inputs(self, inputs, multi_channel_cross_attention):
+    if multi_channel_cross_attention:
+      if len(inputs) < 5:
+        raise ValueError(
+            "TransformerDecoderBlock must have at least 5 inputs, when it uses "
+            "multi_channel_cross_attention. But it got: %d" % len(inputs))
+      elif len(inputs) == 5:
+        input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights = inputs
+        input_pos_embed = None
+        memory_pos_embed = None
+      elif len(inputs) == 6:
+        input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed = inputs
+        memory_pos_embed = None
+      else:
+        input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed = inputs[:7]
+    else:
+      context_attention_weights = None
+      if len(inputs) < 4:
+        raise ValueError(
+            "TransformerDecoderBlock must have at least 4 inputs, but it "
+            "got: %d" % len(inputs))
+      elif len(inputs) == 4:
+        input_tensor, memory, attention_mask, self_attention_mask = inputs
+        input_pos_embed = None
+        memory_pos_embed = None
+      elif len(inputs) == 5:
+        input_tensor, memory, attention_mask, self_attention_mask, input_pos_embed = inputs
+        memory_pos_embed = None
+      else:
+        input_tensor, memory, attention_mask, self_attention_mask, input_pos_embed, memory_pos_embed = inputs[:6]
+    return input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed
+
+  def call(self, inputs, cache=None, decode_loop_step=None):
+    input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed = self._parse_inputs(
+        inputs, self.multi_channel_cross_attention)
     source_tensor = input_tensor
     if self._norm_first:
       input_tensor = self.self_attention_layer_norm(input_tensor)
+    if input_pos_embed is None:
+      self_attn_query = input_tensor
+      self_attn_key = input_tensor
+    else:
+      self_attn_query = input_tensor + input_pos_embed
+      self_attn_key = input_tensor + input_pos_embed
     self_attention_output, cache = self.self_attention(
-        query=input_tensor,
+        query=self_attn_query,
+        key=self_attn_key,
         value=input_tensor,
         attention_mask=self_attention_mask,
         cache=cache,
@@ -400,13 +438,22 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
       source_self_attention_output = self_attention_output
       self_attention_output = self.encdec_attention_layer_norm(
           self_attention_output)
+    if input_pos_embed is None:
+      cross_attn_query = self_attention_output
+    else:
+      cross_attn_query = self_attention_output + input_pos_embed
+    if memory_pos_embed is None:
+      cross_attn_key = memory
+    else:
+      cross_attn_key = memory + memory_pos_embed
     cross_attn_inputs = dict(
-        query=self_attention_output,
+        query=cross_attn_query,
+        key=cross_attn_key,
         value=memory,
         attention_mask=attention_mask)
     if self.multi_channel_cross_attention:
       # Accesses the 5-th input tensor for the doc-attention probabilities.
-      cross_attn_inputs["context_attention_weights"] = inputs[-1]
+      cross_attn_inputs["context_attention_weights"] = context_attention_weights
     attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
     if self._norm_first:
......
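For reference, the input tuples accepted by the patched TransformerDecoderBlock when `multi_channel_cross_attention` is disabled; `decoder_block` and the tensor names below are illustrative, and the trailing positional-embedding entries may be omitted as before. With `multi_channel_cross_attention` enabled, `context_attention_weights` becomes the fifth element and the positional embeddings move to positions six and seven.

    # 4-, 5-, or 6-element input (without multi_channel_cross_attention):
    outputs = decoder_block([
        target_embeddings,     # input_tensor
        encoder_outputs,       # memory
        cross_attention_mask,  # attention_mask, used by the cross-attention
        self_attention_mask,   # self_attention_mask
        target_pos_embed,      # input_pos_embed, optional fifth element
        memory_pos_embed,      # memory_pos_embed, optional sixth element
    ])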
@@ -425,7 +425,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
     base_config = super(TransformerEncoder, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
-  def call(self, encoder_inputs, attention_mask=None):
+  def call(self, encoder_inputs, attention_mask=None, pos_embed=None):
     """Return the output of the encoder.
 
     Args:
@@ -433,14 +433,17 @@
         hidden_size)`.
       attention_mask: A mask for the encoder self-attention layer with shape
         `(batch_size, input_length, input_length)`.
+      pos_embed: A tensor or a float that is added to the query and key of
+        every self-attention layer. Defaults to None.
 
     Returns:
       Output of encoder which is a `float32` tensor with shape
       `(batch_size, input_length, hidden_size)`.
     """
     for layer_idx in range(self.num_layers):
       encoder_inputs = self.encoder_layers[layer_idx](
-          [encoder_inputs, attention_mask])
+          [encoder_inputs, encoder_inputs, attention_mask, pos_embed])
 
     output_tensor = encoder_inputs
     output_tensor = self.output_normalization(output_tensor)
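Correspondingly, a call sketch for the encoder stack; `encoder` is assumed to be a built TransformerEncoder, and `pos_embed` is optional (it is forwarded to every layer and added to the query and key there).

    # features:  (batch_size, input_length, hidden_size)
    # mask:      (batch_size, input_length, input_length)
    # pos_embed: broadcastable to features, e.g. (batch_size, input_length, hidden_size)
    encoded = encoder(features, attention_mask=mask, pos_embed=pos_embed)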
@@ -519,7 +522,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
               attention_initializer=attention_initializer(input_shape[2]),
               name=("layer_%d" % i)))
     self.output_normalization = tf.keras.layers.LayerNormalization(
-        epsilon=1e-6, dtype="float32")
+        epsilon=self._norm_epsilon, dtype="float32")
     super(TransformerDecoder, self).build(input_shape)
 
   def get_config(self):
@@ -545,7 +548,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
            cross_attention_mask=None,
            cache=None,
            decode_loop_step=None,
-           return_all_decoder_outputs=False):
+           return_all_decoder_outputs=False,
+           input_pos_embed=None,
+           memory_pos_embed=None):
     """Return the output of the decoder layer stacks.
 
     Args:
@@ -565,6 +570,10 @@ class TransformerDecoder(tf.keras.layers.Layer):
       return_all_decoder_outputs: Return all decoder layer outputs.
         Note that the outputs are layer normed.
         This is useful when introducing per layer auxiliary loss.
+      input_pos_embed: A tensor or float that is added to the target embedding
+        in every self-attention and cross-attention layer. Defaults to None.
+      memory_pos_embed: A tensor or float that is added to the memory embedding
+        in every cross-attention layer. Defaults to None.
 
     Returns:
       Output of decoder.
@@ -575,7 +584,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
     decoder_outputs = []
     for layer_idx in range(self.num_layers):
       transformer_inputs = [
-          output_tensor, memory, cross_attention_mask, self_attention_mask
+          output_tensor, memory, cross_attention_mask, self_attention_mask,
+          input_pos_embed, memory_pos_embed
       ]
       # Gets the cache for decoding.
       if cache is None:
......
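And a matching call sketch for the decoder stack; `decoder` is assumed to be a built TransformerDecoder whose first two positional arguments are the target and memory tensors (as in the existing call convention), and both positional embeddings default to None.

    decoded = decoder(
        target_embeddings,                          # target: (batch, target_length, hidden)
        encoder_outputs,                            # memory: (batch, input_length, hidden)
        self_attention_mask=self_attention_mask,
        cross_attention_mask=cross_attention_mask,
        input_pos_embed=target_pos_embed,           # added to the target in self/cross attention
        memory_pos_embed=memory_pos_embed)          # added to the memory in cross attention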