Commit 318fa0af authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 345500920
parent 15cbdacf
@@ -207,6 +207,7 @@ def bigbird_block_sparse_attention(
   n = to_seq_length
   wm = from_block_size
   wn = to_block_size
+  dtype = query_layer.dtype
   query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3])
   key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3])
   value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3])
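
A minimal standalone sketch (not part of this commit; shapes and values are made up) of the dtype-propagation pattern introduced above: constants built later in the function are created in the query's dtype, so the same code path also works for float16/bfloat16 inputs.

import tensorflow as tf

# A stand-in query tensor; any float dtype could appear here in practice.
query_layer = tf.random.normal([2, 4, 128, 64], dtype=tf.bfloat16)
dtype = query_layer.dtype                    # bfloat16 here, float32 by default
scale = tf.constant(1.0 / 8.0, dtype=dtype)  # e.g. 1/sqrt(d) with d = 64
scaled = query_layer * scale                 # no implicit upcast to float32
print(scaled.dtype)                          # -> <dtype: 'bfloat16'>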
@@ -224,7 +225,7 @@ def bigbird_block_sparse_attention(
       "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 0],
       key_layer)  # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
   first_product = tf.multiply(first_product, 1.0 / np.sqrt(d))
-  first_product += (1.0 - tf.cast(to_mask, dtype=tf.float32)) * -10000.0
+  first_product += (1.0 - tf.cast(to_mask, dtype=dtype)) * -10000.0
   first_attn_weights = tf.nn.softmax(first_product)  # [b, h, wm, n]
   first_context_layer = tf.einsum(
       "BHQK,BHKD->BHQD", first_attn_weights,
@@ -246,10 +247,11 @@ def bigbird_block_sparse_attention(
   )  # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
   second_seq_pad = tf.concat([
       to_mask[:, :, :, :3 * wn], to_mask[:, :, :, -wn:],
-      tf.ones([b, 1, 1, r * wn], dtype=tf.float32)
+      tf.ones([b, 1, 1, r * wn], dtype=dtype)
   ], 3)
-  second_rand_pad = tf.concat(
-      [tf.ones([b, h, wm, 4 * wn], dtype=tf.float32), rand_mask[:, :, 0]], 3)
+  second_rand_pad = tf.concat([
+      tf.ones([b, h, wm, 4 * wn], dtype=dtype), rand_mask[:, :, 0]
+  ], 3)
   second_product = tf.multiply(second_product, 1.0 / np.sqrt(d))
   second_product += (1.0 -
                      tf.minimum(second_seq_pad, second_rand_pad)) * -10000.0
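
A short illustrative sketch (values invented) of why tf.minimum combines the two pads above: for 0/1 float masks, an elementwise minimum behaves like a logical AND, so a key position is kept only if both the sequence mask and the random-block mask keep it.

import tensorflow as tf

seq_pad = tf.constant([1., 1., 0., 1.])
rand_pad = tf.constant([1., 0., 1., 1.])
combined = tf.minimum(seq_pad, rand_pad)  # -> [1., 0., 0., 1.]
bias = (1.0 - combined) * -10000.0        # penalize positions dropped by either mask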
@@ -332,10 +334,10 @@ def bigbird_block_sparse_attention(
   )  # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
   second_last_seq_pad = tf.concat([
       to_mask[:, :, :, :wn], to_mask[:, :, :, -3 * wn:],
-      tf.ones([b, 1, 1, r * wn], dtype=tf.float32)
+      tf.ones([b, 1, 1, r * wn], dtype=dtype)
   ], 3)
   second_last_rand_pad = tf.concat(
-      [tf.ones([b, h, wm, 4 * wn], dtype=tf.float32), rand_mask[:, :, -1]], 3)
+      [tf.ones([b, h, wm, 4 * wn], dtype=dtype), rand_mask[:, :, -1]], 3)
   second_last_product = tf.multiply(second_last_product, 1.0 / np.sqrt(d))
   second_last_product += (
       1.0 - tf.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0
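
As a rough illustration of the failure mode the dtype change avoids (toy tensors, not from this commit): concatenating a hardcoded float32 ones tensor with a half-precision mask raises a dtype-mismatch error, while building the ones in the mask's dtype works.

import tensorflow as tf

mask_f16 = tf.ones([1, 2], dtype=tf.float16)
try:
  tf.concat([tf.ones([1, 2], dtype=tf.float32), mask_f16], axis=1)
except (tf.errors.InvalidArgumentError, TypeError) as e:
  print("dtype mismatch:", type(e).__name__)
tf.concat([tf.ones([1, 2], dtype=mask_f16.dtype), mask_f16], axis=1)  # works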
@@ -376,8 +378,7 @@ class BigBirdMasks(tf.keras.layers.Layer):
   def call(self, inputs):
     encoder_shape = tf.shape(inputs)
     batch_size, seq_length = encoder_shape[0], encoder_shape[1]
-    # reshape and cast for blocking
-    inputs = tf.cast(inputs, dtype=tf.float32)
+    # reshape for blocking
     blocked_encoder_mask = tf.reshape(
         inputs, (batch_size, seq_length // self._block_size, self._block_size))
     encoder_from_mask = tf.reshape(inputs, (batch_size, 1, seq_length, 1))
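
A standalone sketch, with assumed toy sizes, of the mask blocking BigBirdMasks.call performs above; after this change the layer only reshapes, and the caller supplies the mask in the desired float dtype.

import tensorflow as tf

block_size = 2
mask = tf.constant([[1., 1., 1., 0.]])  # [batch=1, seq_length=4], already float
batch_size, seq_length = tf.shape(mask)[0], tf.shape(mask)[1]
blocked_mask = tf.reshape(
    mask, (batch_size, seq_length // block_size, block_size))
from_mask = tf.reshape(mask, (batch_size, 1, seq_length, 1))
print(blocked_mask.shape, from_mask.shape)  # (1, 2, 2) (1, 1, 4, 1)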
...
@@ -29,7 +29,7 @@ class BigbirdAttentionTest(tf.test.TestCase):
     block_size = 64
     mask_layer = attention.BigBirdMasks(block_size=block_size)
     encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
-    masks = mask_layer(encoder_inputs_mask)
+    masks = mask_layer(tf.cast(encoder_inputs_mask, dtype=tf.float64))
     test_layer = attention.BigBirdAttention(
         num_heads=num_heads,
         key_dim=key_dim,
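
A small sketch (toy shapes) of the caller-side cast the test now performs, since BigBirdMasks no longer casts internally and an integer padding mask would otherwise flow into float attention math.

import tensorflow as tf

encoder_inputs_mask = tf.zeros((2, 8), dtype=tf.int32)       # 0 = padded position
float_mask = tf.cast(encoder_inputs_mask, dtype=tf.float64)  # what the test does now
print(float_mask.dtype)  # float64; any float dtype the attention math runs in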
...
@@ -142,7 +142,8 @@ class BigBirdEncoder(tf.keras.Model):
     self._transformer_layers = []
     data = embeddings
-    masks = attention.BigBirdMasks(block_size=block_size)(mask)
+    masks = attention.BigBirdMasks(block_size=block_size)(
+        tf.cast(mask, embeddings.dtype))
     encoder_outputs = []
     attn_head_dim = hidden_size // num_attention_heads
     for i in range(num_layers):
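
An illustrative sketch (assumed shapes) of the pattern used here: casting the padding mask to the embeddings' dtype keeps the downstream attention math in a single dtype, e.g. under a float16 policy.

import tensorflow as tf

embeddings = tf.random.normal([2, 8, 16], dtype=tf.float16)
mask = tf.ones([2, 8], dtype=tf.int32)
mask = tf.cast(mask, embeddings.dtype)
print(mask.dtype == embeddings.dtype)  # True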
...