Commit 4bd15fa6 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 267435985
parent a009f4fb
......@@ -495,7 +495,23 @@ class Attention(tf.keras.layers.Layer):
class Dense3D(tf.keras.layers.Layer):
"""A Dense Layer using 3D kernel with tf.einsum implementation."""
"""A Dense Layer using 3D kernel with tf.einsum implementation.
Attributes:
num_attention_heads: An integer, number of attention heads for each
multihead attention layer.
size_per_head: An integer, hidden size per attention head.
hidden_size: An integer, dimension of the hidden layer.
kernel_initializer: An initializer for the kernel weight.
bias_initializer: An initializer for the bias.
activation: An activation function to use. If nothing is specified, no
activation is applied.
use_bias: A bool, whether the layer uses a bias.
output_projection: A bool, whether the Dense3D layer is used for output
linear projection.
backward_compatible: A bool, whether the variable shapes are compatible
with checkpoints converted from TF 1.x.
"""
def __init__(self,
num_attention_heads=12,
......@@ -503,9 +519,11 @@ class Dense3D(tf.keras.layers.Layer):
kernel_initializer=None,
bias_initializer="zeros",
activation=None,
use_bias=True,
output_projection=False,
backward_compatible=False,
**kwargs):
"""Inits Dense3D."""
super(Dense3D, self).__init__(**kwargs)
self.num_attention_heads = num_attention_heads
self.size_per_head = size_per_head
......@@ -513,6 +531,7 @@ class Dense3D(tf.keras.layers.Layer):
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
self.activation = activation
self.use_bias = use_bias
self.output_projection = output_projection
self.backward_compatible = backward_compatible
......@@ -565,12 +584,15 @@ class Dense3D(tf.keras.layers.Layer):
initializer=self.kernel_initializer,
dtype=self.dtype,
trainable=True)
self.bias = self.add_weight(
"bias",
shape=bias_shape,
initializer=self.bias_initializer,
dtype=self.dtype,
trainable=True)
if self.use_bias:
self.bias = self.add_weight(
"bias",
shape=bias_shape,
initializer=self.bias_initializer,
dtype=self.dtype,
trainable=True)
else:
self.bias = None
super(Dense3D, self).build(input_shape)
def call(self, inputs):
......@@ -588,7 +610,8 @@ class Dense3D(tf.keras.layers.Layer):
"""
if self.backward_compatible:
kernel = tf.keras.backend.reshape(self.kernel, self.kernel_shape)
bias = tf.keras.backend.reshape(self.bias, self.bias_shape)
bias = (tf.keras.backend.reshape(self.bias, self.bias_shape)
if self.use_bias else None)
else:
kernel = self.kernel
bias = self.bias
......@@ -597,7 +620,8 @@ class Dense3D(tf.keras.layers.Layer):
ret = tf.einsum("abcd,cde->abe", inputs, kernel)
else:
ret = tf.einsum("abc,cde->abde", inputs, kernel)
ret += bias
if self.use_bias:
ret += bias
if self.activation is not None:
return self.activation(ret)
return ret
......
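For reference, a minimal sketch (illustrative shapes and variable names, not part of this commit) of what the two einsum contractions in Dense3D compute: the input projection splits heads as part of the projection itself, and the output projection recombines them.

```python
import tensorflow as tf

batch, length, num_heads, size_per_head, hidden = 2, 5, 4, 16, 64

# Input projection path ("abc,cde->abde"): [B, L, hidden] x [hidden, N, H]
# -> [B, L, N, H], i.e. heads come out already split.
x = tf.random.normal([batch, length, hidden])
w_in = tf.random.normal([hidden, num_heads, size_per_head])
projected = tf.einsum("abc,cde->abde", x, w_in)
print(projected.shape)  # (2, 5, 4, 16)

# Output projection path ("abcd,cde->abe"): [B, L, N, H] x [N, H, hidden]
# -> [B, L, hidden], i.e. heads are recombined by the projection.
w_out = tf.random.normal([num_heads, size_per_head, hidden])
combined = tf.einsum("abcd,cde->abe", projected, w_out)
print(combined.shape)  # (2, 5, 64)
```

With use_bias=False, these contractions are the whole layer; no bias variable is created.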
......@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.bert import modeling as common_layer
class Attention(tf.keras.layers.Layer):
......@@ -45,14 +46,19 @@ class Attention(tf.keras.layers.Layer):
def build(self, input_shape):
"""Builds the layer."""
# Layers for linearly projecting the queries, keys, and values.
self.q_dense_layer = tf.keras.layers.Dense(
self.hidden_size, use_bias=False, name="q")
self.k_dense_layer = tf.keras.layers.Dense(
self.hidden_size, use_bias=False, name="k")
self.v_dense_layer = tf.keras.layers.Dense(
self.hidden_size, use_bias=False, name="v")
self.output_dense_layer = tf.keras.layers.Dense(
self.hidden_size, use_bias=False, name="output_transform")
size_per_head = self.hidden_size // self.num_heads
self.query_dense_layer = common_layer.Dense3D(
self.num_heads, size_per_head, kernel_initializer="glorot_uniform",
use_bias=False, name="query")
self.key_dense_layer = common_layer.Dense3D(
self.num_heads, size_per_head, kernel_initializer="glorot_uniform",
use_bias=False, name="key")
self.value_dense_layer = common_layer.Dense3D(
self.num_heads, size_per_head, kernel_initializer="glorot_uniform",
use_bias=False, name="value")
self.output_dense_layer = common_layer.Dense3D(
self.num_heads, size_per_head, kernel_initializer="glorot_uniform",
use_bias=False, output_projection=True, name="output_transform")
super(Attention, self).build(input_shape)
def get_config(self):
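Because the Dense3D projections emit the per-head layout directly, the explicit split_heads step removed in the next hunk is no longer needed. A rough equivalence sketch, using random weights purely for a shape/value comparison (names here are hypothetical):

```python
import tensorflow as tf

batch, length, hidden_size, num_heads = 2, 7, 64, 4
dim_per_head = hidden_size // num_heads
x = tf.random.normal([batch, length, hidden_size])

# Old path: dense projection to [B, L, hidden], then split heads explicitly.
w = tf.random.normal([hidden_size, hidden_size])
old = tf.reshape(tf.matmul(x, w), [batch, length, num_heads, dim_per_head])
old = tf.transpose(old, [0, 2, 1, 3])          # [B, N, L, H]

# New path: a 3D kernel projects straight to the per-head layout.
w3d = tf.reshape(w, [hidden_size, num_heads, dim_per_head])
new = tf.einsum("abc,cde->abde", x, w3d)       # [B, L, N, H]

# Same values, just a different axis order for the heads.
tf.debugging.assert_near(old, tf.transpose(new, [0, 2, 1, 3]))
```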
......@@ -62,73 +68,35 @@ class Attention(tf.keras.layers.Layer):
"attention_dropout": self.attention_dropout,
}
def split_heads(self, x):
"""Split x into different heads, and transpose the resulting value.
The tensor is transposed to ensure the inner dimensions hold the correct
values during the matrix multiplication.
Args:
x: A tensor with shape [batch_size, length, hidden_size]
Returns:
A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
"""
with tf.name_scope("split_heads"):
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
# Calculate depth of last dimension after it has been split.
depth = (self.hidden_size // self.num_heads)
# Split the last dimension
x = tf.reshape(x, [batch_size, length, self.num_heads, depth])
# Transpose the result
return tf.transpose(x, [0, 2, 1, 3])
def combine_heads(self, x):
"""Combine tensor that has been split.
def call(self, query_input, source_input, bias, training, cache=None,
decode_loop_step=None):
"""Apply attention mechanism to query_input and source_input.
Args:
x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]
Returns:
A tensor with shape [batch_size, length, hidden_size]
"""
with tf.name_scope("combine_heads"):
batch_size = tf.shape(x)[0]
length = tf.shape(x)[2]
x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth]
return tf.reshape(x, [batch_size, length, self.hidden_size])
def call(self, x, y, bias, training, cache=None, decode_loop_step=None):
"""Apply attention mechanism to x and y.
Args:
x: A tensor with shape [batch_size, length_x, hidden_size].
y: A tensor with shape [batch_size, length_y, hidden_size].
bias: A bool, the attention bias that will be added to the result of the
dot product.
query_input: A tensor with shape [batch_size, length_query, hidden_size].
source_input: A tensor with shape [batch_size, length_source,
hidden_size].
bias: A tensor with shape [batch_size, 1, length_query, length_source],
the attention bias that will be added to the result of the dot product.
training: A bool, whether in training mode or not.
cache: (Used during prediction) A dictionary with tensors containing
results of previous attentions. The dictionary must have the items:
{"k": tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]}
where i is the current decoded length.
{"k": tensor with shape [batch_size, i, heads, dim_per_head],
"v": tensor with shape [batch_size, i, heads, dim_per_head]}
where i is the current decoded length for non-padded decode, or max
sequence length for padded decode.
decode_loop_step: An integer, step number of the decoding loop. Used only
for autoregressive inference on TPU.
Returns:
Attention layer output with shape [batch_size, length_x, hidden_size]
Attention layer output with shape [batch_size, length_query, hidden_size]
"""
# Linearly project the query, key and value using different learned
# projections. This is in preparation of splitting them into multiple
# heads. Multi-head attention uses multiple queries, keys, and values
# rather than regular attention (which uses a single query, key, value).
query = self.q_dense_layer(x)
key = self.k_dense_layer(y)
value = self.v_dense_layer(y)
# projections. Splitting heads is automatically done during the linear
# projections --> [batch_size, length, num_heads, dim_per_head].
query = self.query_dense_layer(query_input)
key = self.key_dense_layer(source_input)
value = self.value_dense_layer(source_input)
if cache is not None:
# Combine cached keys and values with new keys and values.
......@@ -136,12 +104,12 @@ class Attention(tf.keras.layers.Layer):
cache_k_shape = cache["k"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
[1, cache_k_shape[1], 1])
[1, cache_k_shape[1], 1, 1])
key = cache["k"] + key * indices
cache_v_shape = cache["v"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
[1, cache_v_shape[1], 1])
[1, cache_v_shape[1], 1, 1])
value = cache["v"] + value * indices
else:
key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
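The indices reshape above gains two trailing singleton axes so it broadcasts against the new [batch, length, num_heads, dim_per_head] cache layout. A small sketch of this one-hot scatter update during padded decode, with assumed shapes; new_k stands in for the single-step key projection:

```python
import tensorflow as tf

batch, max_len, num_heads, dim_per_head = 2, 8, 4, 16
decode_loop_step = 3

cache_k = tf.zeros([batch, max_len, num_heads, dim_per_head])
new_k = tf.random.normal([batch, 1, num_heads, dim_per_head])  # current step only

# The one-hot vector over the length axis broadcasts across batch, heads and
# depth, writing the new key into position `decode_loop_step` while keeping
# the cache shape fixed (required for TPU padded decode).
indices = tf.reshape(
    tf.one_hot(decode_loop_step, max_len, dtype=new_k.dtype),
    [1, max_len, 1, 1])
updated_k = cache_k + new_k * indices
print(updated_k.shape)  # (2, 8, 4, 16)
```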
......@@ -151,18 +119,13 @@ class Attention(tf.keras.layers.Layer):
cache["k"] = key
cache["v"] = value
# Split query, key, value into heads.
query = self.split_heads(query)
key = self.split_heads(key)
value = self.split_heads(value)
# Scale query to prevent the dot product between query and key from growing
# too large.
depth = (self.hidden_size // self.num_heads)
query *= depth ** -0.5
# Calculate dot product attention
logits = tf.matmul(query, key, transpose_b=True)
logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
logits += bias
# Note that softmax internally performs math operations using float32
# for numeric stability. When training with float16, we keep the input
......@@ -170,12 +133,10 @@ class Attention(tf.keras.layers.Layer):
weights = tf.nn.softmax(logits, name="attention_weights")
if training:
weights = tf.nn.dropout(weights, rate=self.attention_dropout)
attention_output = tf.matmul(weights, value)
# Recombine heads --> [batch_size, length, hidden_size]
attention_output = self.combine_heads(attention_output)
attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)
# Run the combined outputs through another linear projection layer.
# Run the outputs through another linear projection layer. Recombining heads
# is automatically done --> [batch_size, length, hidden_size]
attention_output = self.output_dense_layer(attention_output)
return attention_output
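A quick sanity sketch (assumed shapes, not from the commit) showing that the two einsum contractions match the old transpose-then-matmul formulation:

```python
import tensorflow as tf

B, F, T, N, H = 2, 5, 6, 4, 16
query = tf.random.normal([B, F, N, H])
key = tf.random.normal([B, T, N, H])
value = tf.random.normal([B, T, N, H])

# New form: contract over depth, keeping heads as a batch dimension.
logits = tf.einsum("BTNH,BFNH->BNFT", key, query)        # [B, N, F, T]

# Old form: move heads in front, then a plain batched matmul.
q_t = tf.transpose(query, [0, 2, 1, 3])                  # [B, N, F, H]
k_t = tf.transpose(key, [0, 2, 1, 3])                    # [B, N, T, H]
tf.debugging.assert_near(logits, tf.matmul(q_t, k_t, transpose_b=True))

# The weighted sum of values goes straight back to [B, F, N, H],
# so no combine_heads step is needed before the output projection.
weights = tf.nn.softmax(logits)
context = tf.einsum("BNFT,BTNH->BFNH", weights, value)
print(context.shape)  # (2, 5, 4, 16)
```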
......@@ -183,6 +144,7 @@ class Attention(tf.keras.layers.Layer):
class SelfAttention(Attention):
"""Multiheaded self-attention layer."""
def call(self, x, bias, training, cache=None, decode_loop_step=None):
return super(SelfAttention, self).call(x, x, bias, training, cache,
decode_loop_step)
def call(self, query_input, bias, training, cache=None,
decode_loop_step=None):
return super(SelfAttention, self).call(
query_input, query_input, bias, training, cache, decode_loop_step)
......@@ -200,7 +200,7 @@ class Transformer(tf.keras.Model):
# Prepare inputs to decoder layers by shifting targets, adding positional
# encoding and applying dropout.
decoder_inputs = self.embedding_softmax_layer(targets)
decoder_inputs = tf.cast(decoder_inputs, self.params['dtype'])
decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
attention_bias = tf.cast(attention_bias, self.params["dtype"])
with tf.name_scope("shift_targets"):
# Shift targets to the right, and remove the last element
......@@ -218,7 +218,7 @@ class Transformer(tf.keras.Model):
# Run values
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
length, dtype=self.params['dtype'])
length, dtype=self.params["dtype"])
outputs = self.decoder_stack(
decoder_inputs,
encoder_outputs,
......@@ -310,16 +310,18 @@ class Transformer(tf.keras.Model):
# pylint: disable=g-complex-comprehension
init_decode_length = (
max_decode_length if self.params["padded_decode"] else 0)
num_heads = self.params["num_heads"]
dim_per_head = self.params["hidden_size"] // num_heads
cache = {
"layer_%d" % layer: {
"k":
tf.zeros([
batch_size, init_decode_length, self.params["hidden_size"]
batch_size, init_decode_length, num_heads, dim_per_head
],
dtype=self.params["dtype"]),
"v":
tf.zeros([
batch_size, init_decode_length, self.params["hidden_size"]
batch_size, init_decode_length, num_heads, dim_per_head
],
dtype=self.params["dtype"])
} for layer in range(self.params["num_hidden_layers"])
......
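With the cache preallocated in the per-head layout, Dense3D outputs can be written into it without any reshape. An illustrative construction using a made-up params dict (only the keys touched by this change):

```python
import tensorflow as tf

params = {"num_heads": 4, "hidden_size": 64, "num_hidden_layers": 2,
          "dtype": tf.float32, "padded_decode": True}
batch_size, max_decode_length = 2, 8

init_decode_length = max_decode_length if params["padded_decode"] else 0
num_heads = params["num_heads"]
dim_per_head = params["hidden_size"] // num_heads

cache = {
    "layer_%d" % layer: {
        "k": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head],
                      dtype=params["dtype"]),
        "v": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head],
                      dtype=params["dtype"]),
    } for layer in range(params["num_hidden_layers"])
}
print(cache["layer_0"]["k"].shape)  # (2, 8, 4, 16)
```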
......@@ -32,6 +32,7 @@ class TransformerLayersTest(tf.test.TestCase):
hidden_size = 64
num_heads = 4
dropout = 0.5
dim_per_head = hidden_size // num_heads
layer = attention_layer.SelfAttention(hidden_size, num_heads, dropout)
self.assertDictEqual(layer.get_config(), {
"hidden_size": hidden_size,
......@@ -42,13 +43,13 @@ class TransformerLayersTest(tf.test.TestCase):
x = tf.ones([1, length, hidden_size])
bias = tf.ones([1])
cache = {
"k": tf.zeros([1, 0, hidden_size]),
"v": tf.zeros([1, 0, hidden_size]),
"k": tf.zeros([1, 0, num_heads, dim_per_head]),
"v": tf.zeros([1, 0, num_heads, dim_per_head]),
}
y = layer(x, bias, training=True, cache=cache)
self.assertEqual(y.shape, (1, length, 64,))
self.assertEqual(cache["k"].shape, (1, length, 64,))
self.assertEqual(cache["v"].shape, (1, length, 64,))
self.assertEqual(cache["k"].shape, (1, length, num_heads, dim_per_head,))
self.assertEqual(cache["v"].shape, (1, length, num_heads, dim_per_head,))
def test_embedding_shared_weights(self):
vocab_size = 50
......