Commit 09e6e71c authored by Zihan Wang

lint

parent 32867f40
@@ -18,27 +18,20 @@ Longformer attention block. Modified From huggingface/transformers
# pylint: disable=g-classes-have-attributes
import math
import string

import tensorflow as tf
import numpy as np

from keras.layers import core
from keras.layers import einsum_dense
from keras.utils import tf_utils
from official.modeling.tf_utils import get_shape_list

_CHR_IDX = string.ascii_lowercase


def _build_attention_equation(rank, attn_axes):
  """Builds einsum equations for the attention computation.

  Query, key, value inputs after projection are expected to have the shape as:
@@ -64,7 +57,7 @@ def _build_attention_equation(rank, attn_axes):
  # `batch_dims` includes the head dim.
  batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
  letter_offset = rank
  source_notation = ''
  for i in range(rank):
    if i in batch_dims or i == rank - 1:
      source_notation += target_notation[i]

@@ -72,23 +65,21 @@ def _build_attention_equation(rank, attn_axes):
      source_notation += _CHR_IDX[letter_offset]
      letter_offset += 1

  product_notation = ''.join([target_notation[i] for i in batch_dims] +
                             [target_notation[i] for i in attn_axes] +
                             [source_notation[i] for i in attn_axes])
  dot_product_equation = f'{source_notation},{target_notation}->{product_notation}'
  attn_scores_rank = len(product_notation)
  combine_equation = f'{product_notation},{source_notation}->{target_notation}'
  return dot_product_equation, combine_equation, attn_scores_rank
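
# Example (illustration; not part of this commit): for the rank-4 projected
# shape (batch, seq, heads, head_dim) with attn_axes=(1,), and assuming the
# elided target_notation construction matches the upstream Keras helper this
# module mirrors, the generated equations are:
#   >>> _build_attention_equation(rank=4, attn_axes=(1,))
#   ('aecd,abcd->acbe', 'acbe,aecd->abcd', 4)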


def _build_proj_equation(free_dims, bound_dims, output_dims):
  """Builds an einsum equation for projections inside multi-head attention."""
  input_str = ''
  kernel_str = ''
  output_str = ''
  bias_axes = ''
  letter_offset = 0
  for i in range(free_dims):
    char = _CHR_IDX[i + letter_offset]

@@ -107,7 +98,7 @@ def _build_proj_equation(free_dims, bound_dims, output_dims):
    kernel_str += char
    output_str += char
    bias_axes += char
  equation = f'{input_str},{kernel_str}->{output_str}'
  return equation, bias_axes, len(output_str)
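
# Example (illustration; not part of this commit): projecting a
# (batch, seq, hidden) input onto per-head (heads, head_dim) outputs uses
# free_dims=2, bound_dims=1, output_dims=2; assuming the elided middle of this
# helper matches the upstream Keras version, the expected result is:
#   >>> _build_proj_equation(free_dims=2, bound_dims=1, output_dims=2)
#   ('abc,cde->abde', 'de', 4)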

@@ -115,8 +106,17 @@ def _build_proj_equation(free_dims, bound_dims, output_dims):
def _get_output_shape(output_rank, known_last_dims):
  return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerAttention(tf.keras.layers.MultiHeadAttention):
  """Longformer attention layer.

  Args:
    attention_window: int representing the window size for local attention.
    layer_id: int, the id of the layer.
    global_attention_size: the size of global attention used for each token.
  """

  def __init__(self,
               attention_window,
               layer_id,
@@ -124,14 +124,16 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
               **kwargs):
    super().__init__(**kwargs)
    self._layer_id = layer_id
    self._attention_window = attention_window
    assert (
        self._attention_window % 2 == 0
    ), f"`attention_window` for layer {self._layer_id} has to be an even " \
       f"value. Given {self._attention_window}"
    assert (
        self._attention_window > 0
    ), f"`attention_window` for layer {self._layer_id} has to be positive. " \
       f"Given {self._attention_window}"
    self._one_sided_attn_window_size = self._attention_window // 2
    self.global_attention_size = global_attention_size
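
  # Example (illustration; not part of this commit): constructing the layer
  # with a 4-token one-sided local window and global attention on the first
  # two tokens. num_heads / key_dim are forwarded to
  # tf.keras.layers.MultiHeadAttention, as in the unit test below;
  # attention_window=8 and global_attention_size=2 are made-up values.
  #   >>> attn = LongformerAttention(num_heads=2, key_dim=4, attention_window=8,
  #   ...                            layer_id=0, global_attention_size=2)
  #   >>> attn._one_sided_attn_window_size
  #   4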

  def _build_from_signature(self, query, value, key=None):

@@ -237,7 +239,6 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
           attention_mask=None,
           is_index_masked=None,
           is_index_global_attn=None,
           training=None):
    """Applies Dot-product attention with query, key, value tensors.

    This function defines the computation inside `call` with projected

@@ -256,7 +257,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
      attention_scores: Multi-headed attention weights.
    """
    if not self._built_from_signature:
      self._build_from_signature(query=hidden_states, value=hidden_states,
                                 key=hidden_states)

    # N = `num_attention_heads`
    # H = `size_per_head`

@@ -272,7 +274,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    # Note: Applying scalar multiply at the smaller end of einsum improves
    # XLA performance, but may introduce slight numeric differences in
    # the Transformer attention head.
    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
    batch_size, seq_len, num_heads, head_dim = get_shape_list(query)

    # attn_probs = (batch_size, seq_len, num_heads, window*2+1)

@@ -293,8 +295,12 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(attn_scores),
          [batch_size, seq_len, self._num_heads,
           self._one_sided_attn_window_size * 2 + 1],
          message=f"attn_probs should be of size "
          f"({batch_size}, {seq_len}, {num_heads}, "
          f"{self._one_sided_attn_window_size * 2 + 1}),"
          f" but is of size {get_shape_list(attn_scores)}",
      )

    # compute global attn indices required throughout forward fn
@@ -303,7 +309,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ) = self._get_global_attn_indices(is_index_global_attn,
                                      self.global_attention_size)

    # this function is only relevant for global attention
    if self.global_attention_size > 0:
      attn_scores = self._concat_with_global_key_attn_probs(

@@ -320,14 +327,18 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    attn_probs = tf.nn.softmax(attn_scores, axis=-1)

    # softmax sometimes inserts NaN if all positions are masked,
    # replace them with 0
    # Make sure to create a mask with the proper shape:
    # if is_global_attn==True => [batch_size, seq_len, self.num_heads,
    # self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
    # if is_global_attn==False => [batch_size, seq_len, self.num_heads,
    # self.one_sided_attn_window_size * 2 + 1]
    if self.global_attention_size > 0:
      masked_index = tf.tile(
          is_index_masked[:, :, None, None],
          (1, 1, self._num_heads,
           self._one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
      )
    else:
      masked_index = tf.tile(

@@ -347,14 +358,17 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
        tf.debugging.assert_equal(
            get_shape_list(layer_head_mask),
            [self._num_heads],
            message=f"Head mask for a single layer should be of size "
            f"{(self._num_heads)}, but is "
            f"{get_shape_list(layer_head_mask)}",
        )
      attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

    # apply dropout
    attn_probs = self._dropout_layer(attn_probs, training=training)
    value_vectors = tf.reshape(value, (batch_size, seq_len, self._num_heads,
                                       self._key_dim))

    # if global attention, compute sum of global and local attn
    if self.global_attention_size > 0:

@@ -377,12 +391,14 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
          message="Unexpected size",
      )

    attn_output = tf.reshape(attn_output, (
        batch_size, seq_len, self._num_heads * self._key_dim))  # FIXME

    # compute value for global attention and overwrite to attention output
    # TODO: remove the redundant computation
    if self.global_attention_size > 0:
      attn_output, global_attn_probs = \
          self._compute_global_attn_output_from_hidden(
            attn_output=attn_output,
            hidden_states=hidden_states,
            max_num_global_attn_indices=max_num_global_attn_indices,

@@ -394,16 +410,16 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
            training=training,
          )
    else:
      global_attn_probs = tf.zeros(
          (batch_size, self._num_heads, max_num_global_attn_indices, seq_len))

    # make sure that local attention probabilities are set to 0 for indices of
    # global attn
    if self.global_attention_size > 0:
      masked_global_attn_index = tf.tile(
          is_index_global_attn[:, :, None, None],
          (1, 1, self._num_heads,
           self._one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
      )
    else:
      masked_global_attn_index = tf.tile(

@@ -413,7 +429,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    attn_probs = tf.where(
        masked_global_attn_index,
        tf.zeros(get_shape_list(masked_global_attn_index),
                 dtype=attn_probs.dtype),
        attn_probs,
    )

@@ -432,9 +449,10 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
  def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
    """
    Matrix multiplication of query and key tensors using a sliding window
    attention pattern. This implementation splits the input into overlapping
    chunks of size 2w (e.g. 512 for pretrained Longformer) with an overlap of
    size window_overlap.
    """
    batch_size, seq_len, num_heads, head_dim = get_shape_list(query)

@@ -442,22 +460,26 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
      tf.debugging.assert_equal(
          seq_len % (window_overlap * 2),
          0,
          message=f"Sequence length should be multiple of {window_overlap * 2}. "
          f"Given {seq_len}",
      )
      tf.debugging.assert_equal(
          get_shape_list(query),
          get_shape_list(key),
          message=f"Shape of query and key should be equal, but got query: "
          f"{get_shape_list(query)} and key: {get_shape_list(key)}",
      )

    chunks_count = seq_len // window_overlap - 1

    # group batch_size and num_heads dimensions into one,
    # then chunk seq_len into chunks of size window_overlap * 2
    query = tf.reshape(
        tf.transpose(query, (0, 2, 1, 3)),
        (batch_size * num_heads, seq_len, head_dim),
    )
    key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)),
                     (batch_size * num_heads, seq_len, head_dim))

    chunked_query = self._chunk(query, window_overlap)
    chunked_key = self._chunk(key, window_overlap)

@@ -466,24 +488,31 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
    # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
    chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
    chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query,
                                         chunked_key)  # multiply

    # convert diagonals into columns
    paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
    diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
        chunked_attention_scores, paddings)

    # allocate space for the overall attention matrix where the chunks are
    # combined. The last dimension has (window_overlap * 2 + 1) columns. The
    # first (window_overlap) columns are the window_overlap lower triangles
    # (attention from a word to window_overlap previous words). The following
    # column is attention score from each word to itself, then followed by
    # window_overlap columns for the upper triangle.

    # copy parts from diagonal_chunked_attention_scores into the combined matrix
    # of attentions - copying the main diagonal and the upper triangle
    # TODO: This code is most likely not very efficient and should be improved
    diagonal_attn_scores_up_triang = tf.concat(
        [
            diagonal_chunked_attention_scores[:, :, :window_overlap,
                                              : window_overlap + 1],
            diagonal_chunked_attention_scores[:, -1:, window_overlap:,
                                              : window_overlap + 1],
        ],
        axis=1,
    )

@@ -495,7 +524,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
                (batch_size * num_heads, 1, window_overlap, window_overlap),
                dtype=diagonal_chunked_attention_scores.dtype,
            ),
            diagonal_chunked_attention_scores[:, :, -(window_overlap + 1): -1,
                                              window_overlap + 1:],
        ],
        axis=1,
    )

@@ -520,7 +550,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
        )
        < 1
    )
    diagonal_attn_scores_low_triang = tf.where(
        first_chunk_mask,
        diagonal_attn_scores_first_chunk,

@@ -541,7 +571,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
        (0, 2, 1, 3),
    )

    diagonal_attention_scores = self._mask_invalid_locations(
        diagonal_attention_scores, window_overlap)

    return diagonal_attention_scores

@@ -549,13 +580,15 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
  def _mask_invalid_locations(input_tensor, window_overlap):
    # create correct upper triangle bool mask
    mask_2d_upper = tf.reverse(
        tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)),
                            -1, 0),
        axis=[0],
    )

    # pad to full matrix
    padding = tf.convert_to_tensor(
        [[0, get_shape_list(input_tensor)[1] - window_overlap],
         [0, get_shape_list(input_tensor)[3] - window_overlap - 1]]
    )

    # create lower mask

@@ -565,20 +598,23 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])

    # broadcast to full matrix
    mask_4d = tf.tile(mask_2d[None, :, None, :],
                      (get_shape_list(input_tensor)[0], 1, 1, 1))

    # inf tensor used for masking
    inf_tensor = -float("inf") * tf.ones_like(input_tensor)

    # mask
    input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor,
                            input_tensor)

    return input_tensor
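
  # For example (illustration, matching test_mask_invalid_locations in the test
  # file below): on a (1, 3, 4, 4) chunked score tensor, window_overlap=1 sets
  # 8 entries to -inf and window_overlap=2 sets 24 entries to -inf.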

  def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value,
                                              window_overlap):
    """
    Same as _sliding_chunks_query_key_matmul but for attn_probs and value
    tensors. Returned tensor will be of the same shape as `attn_probs`.
    """
    batch_size, seq_len, num_heads, head_dim = get_shape_list(value)

@@ -602,7 +638,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    chunks_count = seq_len // window_overlap - 1

    # group batch_size and num_heads dimensions into one, then chunk seq_len
    # into chunks of size 2 window overlap
    chunked_attn_probs = tf.reshape(
        tf.transpose(attn_probs, (0, 2, 1, 3)),
        (

@@ -619,13 +656,17 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
        (batch_size * num_heads, seq_len, head_dim),
    )

    # pad seq_len with w at the beginning of the sequence and another window
    # overlap at the end
    paddings = tf.convert_to_tensor(
        [[0, 0], [window_overlap, window_overlap], [0, 0]])
    padded_value = tf.pad(value, paddings, constant_values=-1)

    # chunk padded_value into chunks of size 3 window overlap and an overlap of
    # size window overlap
    frame_size = 3 * window_overlap * head_dim
    frame_hop_size = (get_shape_list(padded_value)[1] * head_dim -
                      frame_size) // chunks_count
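    # Worked example (illustration): with window_overlap = 2, head_dim = 4 and
    # seq_len = 8, padded_value covers 8 + 2 * 2 = 12 positions, so
    # frame_size = 3 * 2 * 4 = 24 and frame_hop_size = (12 * 4 - 24) // 3 = 8;
    # tf.signal.frame then yields chunks_count + 1 = 4 overlapping chunks of
    # 3 * window_overlap = 6 positions each.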
    chunked_value = tf.signal.frame(
        tf.reshape(padded_value, (batch_size * num_heads, -1)),
        frame_size,

@@ -639,7 +680,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(chunked_value),
          [batch_size * num_heads, chunks_count + 1, 3 * window_overlap,
           head_dim],
          message="Chunked value has the wrong shape",
      )

@@ -658,8 +700,10 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    hidden_states_padded = tf.pad(
        hidden_states_padded, paddings
    )  # padding value is not important because it will be overwritten
    batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(
        hidden_states_padded)
    hidden_states_padded = tf.reshape(hidden_states_padded, (
        batch_size, chunk_size, hidden_dim, seq_length))
    return hidden_states_padded

@@ -681,21 +725,27 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
          0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
          0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
    """
    total_num_heads, num_chunks, window_overlap, hidden_dim = get_shape_list(
        chunked_hidden_states)
    paddings = tf.convert_to_tensor(
        [[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
    chunked_hidden_states = tf.pad(
        chunked_hidden_states, paddings
    )
    chunked_hidden_states = tf.reshape(
        chunked_hidden_states, (total_num_heads, num_chunks, -1)
    )
    chunked_hidden_states = chunked_hidden_states[
        :, :, :-window_overlap
    ]
    chunked_hidden_states = tf.reshape(
        chunked_hidden_states,
        (
            total_num_heads, num_chunks, window_overlap,
            window_overlap + hidden_dim),
    )
    chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]

    return chunked_hidden_states

@@ -709,16 +759,21 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    # define frame size and frame stride (similar to convolution)
    frame_hop_size = window_overlap * hidden_dim
    frame_size = 2 * frame_hop_size
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length * hidden_dim))

    # chunk with overlap
    chunked_hidden_states = tf.signal.frame(hidden_states, frame_size,
                                            frame_hop_size)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(chunked_hidden_states),
          [batch_size, num_output_chunks, frame_size],
          message=f"Make sure chunking is correctly applied. Chunked hidden "
          f"states should have output dimension"
          f" {[batch_size, frame_size, num_output_chunks]}, but got "
          f"{get_shape_list(chunked_hidden_states)}.",
      )

    chunked_hidden_states = tf.reshape(

@@ -738,19 +793,25 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    max_num_global_attn_indices = global_attention_size

    row_indices = tf.range(batch_size)
    row_indices = tf.repeat(tf.expand_dims(row_indices, axis=0),
                            repeats=[global_attention_size], axis=0)
    row_indices = tf.reshape(row_indices,
                             (batch_size * global_attention_size, 1))

    col_indices = tf.range(global_attention_size)
    col_indices = tf.repeat(tf.expand_dims(col_indices, axis=1),
                            repeats=[batch_size], axis=0)

    is_index_global_attn_nonzero = tf.concat((row_indices, col_indices), axis=1)

    # this is actually the same as `is_index_global_attn_nonzero`,
    # since we assume all global attention is of the same size
    is_local_index_global_attn_nonzero = tf.concat((row_indices, col_indices),
                                                   axis=1)

    # empty tensor
    is_local_index_no_global_attn_nonzero = tf.reshape(
        tf.expand_dims(tf.range(0), axis=1), (0, 2))

    return (
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,

@@ -786,11 +847,14 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    )

    # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors,
                                           key_vectors_only_global)

    # (batch_size, max_num_global_attn_indices, seq_len, num_heads)
    attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key,
                                                    (0, 3, 1, 2))
    mask_shape = (
        get_shape_list(is_local_index_no_global_attn_nonzero)[0],
    ) + tuple(get_shape_list(attn_probs_from_global_key_trans)[-2:])
    mask = tf.ones(mask_shape) * -10000.0

@@ -804,7 +868,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    )

    # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans,
                                              (0, 2, 3, 1))

    # concat to attn_probs
    # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)

@@ -823,10 +888,10 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    # cut local attn probs to global only
    attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]

    # select global value vectors
    global_value_vectors = tf.gather_nd(value_vectors,
                                        is_index_global_attn_nonzero)

    # create only global value vectors
    value_vectors_only_global = tf.scatter_nd(

@@ -841,10 +906,12 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    )

    # compute attn output only global
    attn_output_only_global = tf.einsum("blhs,bshd->blhd",
                                        attn_probs_only_global,
                                        value_vectors_only_global)

    # reshape attn probs
    attn_probs_without_global = attn_probs[:, :, :,
                                           max_num_global_attn_indices:]

    # compute attn output with global
    attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(

@@ -868,15 +935,19 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    batch_size, seq_len = get_shape_list(hidden_states)[:2]

    # prepare global hidden states
    global_attn_hidden_states = tf.gather_nd(hidden_states,
                                             is_index_global_attn_nonzero)
    global_attn_hidden_states = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        global_attn_hidden_states,
        shape=(
            batch_size, max_num_global_attn_indices,
            self._num_heads * self._key_dim),
    )

    # global key, query, value
    global_query_vectors_only_global = self._global_query_dense(
        global_attn_hidden_states)
    global_key_vectors = self._global_key_dense(hidden_states)
    global_value_vectors = self._global_value_dense(hidden_states)

@@ -884,18 +955,24 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    global_query_vectors_only_global /= tf.math.sqrt(
        tf.cast(self._key_dim, dtype=global_query_vectors_only_global.dtype)
    )
    global_query_vectors_only_global = self.reshape_and_transpose(
        global_query_vectors_only_global, batch_size)
    global_key_vectors = self.reshape_and_transpose(global_key_vectors,
                                                    batch_size)
    global_value_vectors = self.reshape_and_transpose(global_value_vectors,
                                                      batch_size)

    # compute attn scores
    global_attn_scores = tf.matmul(global_query_vectors_only_global,
                                   global_key_vectors, transpose_b=True)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(global_attn_scores),
          [batch_size * self._num_heads, max_num_global_attn_indices, seq_len],
          message=f"global_attn_scores have the wrong size. Size should be "
          f"{(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, "
          f"but is {get_shape_list(global_attn_scores)}.",
      )

    global_attn_scores = tf.reshape(

@@ -903,11 +980,13 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
        (batch_size, self._num_heads, max_num_global_attn_indices, seq_len),
    )

    global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
    mask_shape = (
        get_shape_list(is_local_index_no_global_attn_nonzero)[0],
    ) + tuple(get_shape_list(global_attn_scores_trans)[-2:])
    global_attn_mask = tf.ones(mask_shape) * -10000.0
    global_attn_mask = tf.cast(global_attn_mask,
                               dtype=global_attn_scores_trans.dtype)

    # scatter mask
    global_attn_scores_trans = tf.tensor_scatter_nd_update(

@@ -918,7 +997,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))

    # mask global attn scores
    attn_mask = tf.tile(is_index_masked[:, None, None, :],
                        (1, get_shape_list(global_attn_scores)[1], 1, 1))
    global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
    global_attn_scores = tf.reshape(
        global_attn_scores,

@@ -934,17 +1014,22 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
        tf.debugging.assert_equal(
            get_shape_list(layer_head_mask),
            [self._num_heads],
            message=f"Head mask for a single layer should be of size "
            f"{(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
        )
      global_attn_probs_float = tf.reshape(layer_head_mask,
                                           (1, -1, 1, 1)) * tf.reshape(
          global_attn_probs_float,
          (batch_size, self._num_heads, max_num_global_attn_indices, seq_len)
      )
      global_attn_probs_float = tf.reshape(
          global_attn_probs_float,
          (batch_size * self._num_heads, max_num_global_attn_indices, seq_len)
      )

    # dropout
    global_attn_probs = self._global_dropout_layer(global_attn_probs_float,
                                                   training=training)

    # global attn output
    global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)

@@ -952,8 +1037,11 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(global_attn_output),
          [batch_size * self._num_heads, max_num_global_attn_indices,
           self._key_dim],
          message=f"global_attn_output tensor has the wrong size. Size should be "
          f"{(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, "
          f"but is {get_shape_list(global_attn_output)}.",
      )

    global_attn_output = tf.reshape(

@@ -977,7 +1065,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
    )

    global_attn_probs = tf.reshape(
        global_attn_probs,
        (batch_size, self._num_heads, max_num_global_attn_indices, seq_len)
    )

    attn_output = self._output_dense(attn_output)

@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_attention."""

import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.keras import \
    keras_parameterized  # pylint: disable=g-direct-tensorflow-import

from official.projects.longformer import longformer_attention
from official.modeling.tf_utils import get_shape_list

@@ -56,7 +57,7 @@ def _create_mock_attention_data(
  if include_mask:
    mask_shape = (batch_size, num_heads, q_seq_length, total_seq_length)
    mask_data = np.random.randint(2, size=mask_shape).astype('float32')
    mask_data = dict(attention_mask=mask_data)
    data.update(mask_data)

@@ -65,6 +66,12 @@ def _create_mock_attention_data(
@keras_parameterized.run_all_keras_modes
class LongformerAttentionTest(keras_parameterized.TestCase):

  def setUp(self):
    super(LongformerAttentionTest, self).setUp()
    np.random.seed(0)
    tf.random.set_seed(0)

  def _get_hidden_states(self):
    return tf.convert_to_tensor(
        [

@@ -116,25 +123,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
  def test_diagonalize(self):
    hidden_states = self._get_hidden_states()
    hidden_states = tf.reshape(hidden_states,
                               (1, 8, 4))  # set seq length = 8, hidden dim = 4
    chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)
    window_overlap_size = get_shape_list(chunked_hidden_states)[2]
    self.assertTrue(window_overlap_size == 4)

    padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(
        chunked_hidden_states)
    self.assertTrue(
        get_shape_list(padded_hidden_states)[-1] ==
        get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
    )

    # first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
    tf.debugging.assert_near(padded_hidden_states[0, 0, 0, :4],
                             chunked_hidden_states[0, 0, 0], rtol=1e-3)
    tf.debugging.assert_near(padded_hidden_states[0, 0, 0, 4:],
                             tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3)

    # last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
    tf.debugging.assert_near(padded_hidden_states[0, 0, -1, 3:],
                             chunked_hidden_states[0, 0, -1], rtol=1e-3)
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, -1, :3],
        tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3
    )

  def test_pad_and_transpose_last_two_dims(self):

@@ -142,16 +157,21 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    self.assertTrue(get_shape_list(hidden_states), [1, 8, 4])

    # pad along seq length dim
    paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]],
                           dtype=tf.dtypes.int32)
    hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)
    padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(
        hidden_states, paddings)
    self.assertTrue(get_shape_list(padded_hidden_states) == [1, 1, 8, 5])

    expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
    tf.debugging.assert_near(expected_added_dim,
                             padded_hidden_states[0, 0, -1, :], rtol=1e-6)
    tf.debugging.assert_near(
        hidden_states[0, 0, -1, :],
        tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
    )
def test_mask_invalid_locations(self): def test_mask_invalid_locations(self):
...@@ -159,39 +179,55 @@ class LongformerAttentionTest(keras_parameterized.TestCase): ...@@ -159,39 +179,55 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
batch_size = 1 batch_size = 1
seq_length = 8 seq_length = 8
hidden_size = 4 hidden_size = 4
hidden_states = tf.reshape(hidden_states, (batch_size, seq_length, hidden_size)) hidden_states = tf.reshape(hidden_states,
hidden_states = longformer_attention.LongformerAttention._chunk(hidden_states, window_overlap=2) (batch_size, seq_length, hidden_size))
hidden_states = longformer_attention.LongformerAttention._chunk(
hid_states_1 = longformer_attention.LongformerAttention._mask_invalid_locations(hidden_states, 1) hidden_states, window_overlap=2)
hid_states_2 = longformer_attention.LongformerAttention._mask_invalid_locations(hidden_states, 2)
hid_states_3 = longformer_attention.LongformerAttention._mask_invalid_locations(hidden_states[:, :, :, :3], 2) hid_states_1 = longformer_attention.LongformerAttention._mask_invalid_locations(
hid_states_4 = longformer_attention.LongformerAttention._mask_invalid_locations(hidden_states[:, :, 2:, :], 2) hidden_states, 1)
hid_states_2 = longformer_attention.LongformerAttention._mask_invalid_locations(
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)) == 8) hidden_states, 2)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)) == 24) hid_states_3 = longformer_attention.LongformerAttention._mask_invalid_locations(
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)) == 24) hidden_states[:, :, :, :3], 2)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)) == 12) hid_states_4 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states[:, :, 2:, :], 2)
self.assertTrue(tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)) == 8)
self.assertTrue(tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)) == 24)
self.assertTrue(tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)) == 24)
self.assertTrue(tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)) == 12)
def test_chunk(self): def test_chunk(self):
    hidden_states = self._get_hidden_states()
    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length, hidden_size))

    chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)

    # expected slices across chunk and seq length dim
    expected_slice_along_seq_length = tf.convert_to_tensor(
        [0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
    expected_slice_along_chunk = tf.convert_to_tensor(
        [0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)

    self.assertTrue(get_shape_list(chunked_hidden_states) == [1, 3, 4, 4])
    tf.debugging.assert_near(chunked_hidden_states[0, :, 0, 0],
                             expected_slice_along_seq_length, rtol=1e-3)
    tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0],
                             expected_slice_along_chunk, rtol=1e-3)
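For reference, the `[1, 3, 4, 4]` shape asserted above corresponds to overlapping chunks of length `2 * window_overlap` taken with stride `window_overlap`. A minimal sketch of that chunking with `tf.signal.frame` (not the `_chunk` implementation under test) reproduces the shape:

```python
import tensorflow as tf

def chunk_with_overlap(hidden_states, window_overlap):
  """Splits (batch, seq_len, hidden) into overlapping chunks.

  Returns (batch, num_chunks, 2 * window_overlap, hidden); consecutive
  chunks overlap by `window_overlap` positions.
  """
  return tf.signal.frame(hidden_states,
                         frame_length=2 * window_overlap,
                         frame_step=window_overlap,
                         axis=1)

x = tf.reshape(tf.range(32, dtype=tf.float32), (1, 8, 4))
print(chunk_with_overlap(x, window_overlap=2).shape)  # (1, 3, 4, 4)
```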
  def test_layer_local_attn(self):
    hidden_states = self._get_hidden_states()
    batch_size, seq_length, _ = hidden_states.shape

    layer = longformer_attention.LongformerAttention(
        num_heads=2,
        key_dim=4,
@@ -203,14 +239,15 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
    is_index_global_attn = tf.math.greater(attention_mask, 1)

    attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0,
                              attention_mask[:, :, None, None])
    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)

    output_hidden_states = layer(
        hidden_states=hidden_states, attention_mask=attention_mask,
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )[0]

    self.assertEqual(output_hidden_states.shape, (1, 4, 8))
@@ -226,32 +263,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
    )
    hidden_states = self._get_hidden_states()
    hidden_states = tf.concat(
        [self._get_hidden_states(), self._get_hidden_states() - 0.5], axis=0)
    batch_size, seq_length, hidden_size = hidden_states.shape

    # create attn mask
    attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
    attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)

    attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] == 0, 10000.0,
                                attention_mask_1)
    attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0,
                                attention_mask_1)
    attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] == 0, 10000.0,
                                attention_mask_2)

    attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
    is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)

    output_hidden_states = layer(
        hidden_states=hidden_states, attention_mask=-tf.math.abs(attention_mask),
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )[0]

    self.assertEqual(output_hidden_states.shape, (2, 4, 8))
if __name__ == '__main__':
  np.random.seed(0)
  tf.random.set_seed(0)
  tf.test.main()
@@ -23,29 +23,16 @@ from absl import logging
import tensorflow as tf

from official.nlp.modeling import layers
from official.projects.longformer.longformer_encoder_block import \
    LongformerEncoderBlock
from official.modeling.tf_utils import get_shape_list

_Initializer = Union[str, tf.keras.initializers.Initializer]

_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)


# Transferred from huggingface.longformer.TFLongformerMainLayer &
# TFLongformerEncoder.
class LongformerEncoder(tf.keras.layers.Layer):
  """Bi-directional Transformer-based encoder network with Longformer attention.

  This network implements a bi-directional Transformer-based encoder as
  described in "BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding" (https://arxiv.org/abs/1810.04805), with the
  standard self-attention layers replaced by Longformer sliding-window (and
  optional global) attention. It includes the embedding lookups and
  transformer layers, but not the masked language model or classification
  task networks.

  Args:
    vocab_size: The size of the token vocabulary.
    attention_window: A list of ints giving the attention window size for each
      layer.
@@ -165,15 +152,14 @@ class LongformerEncoder(tf.keras.layers.Layer):
          num_attention_heads=num_attention_heads,
          inner_dim=inner_dim,
          inner_activation=inner_activation,
          attention_window=attention_window[i],
          layer_id=i,
          output_dropout=output_dropout,
          attention_dropout=attention_dropout,
          norm_first=norm_first,
          output_range=output_range if i == num_layers - 1 else None,
          kernel_initializer=initializer,
          name=f'transformer/layer_{i}')
      self._transformer_layers.append(layer)
    self._pooler_layer = tf.keras.layers.Dense(
@@ -198,7 +184,6 @@ class LongformerEncoder(tf.keras.layers.Layer):
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
        'attention_window': attention_window,
        'global_attention_size': global_attention_size,
        'pad_token_id': pad_token_id,
@@ -214,9 +199,10 @@ class LongformerEncoder(tf.keras.layers.Layer):
      word_ids = inputs.get('input_word_ids')  # input_ids
      mask = inputs.get('input_mask')  # attention_mask
      type_ids = inputs.get('input_type_ids')  # token_type_ids
      word_embeddings = inputs.get('input_word_embeddings',
                                   None)  # input_embeds
    else:
      raise ValueError(f'Unexpected inputs type to {self.__class__}.')

    (
        padding_len,
@@ -247,34 +233,35 @@ class LongformerEncoder(tf.keras.layers.Layer):
    batch_size, seq_len = get_shape_list(mask)
    # create masks with fixed len global_attention_size
    mask = tf.transpose(tf.concat(
        values=[tf.ones((self._global_attention_size, batch_size), tf.int32) * 2,
                tf.transpose(mask)[self._global_attention_size:]], axis=0))
    is_index_masked = tf.math.less(mask, 1)
    is_index_global_attn = tf.transpose(tf.concat(values=[
        tf.ones((self._global_attention_size, batch_size), tf.bool),
        tf.zeros((seq_len - self._global_attention_size,
                  batch_size), tf.bool)
    ], axis=0))

    # Longformer
    attention_mask = mask
    extended_attention_mask = tf.reshape(
        attention_mask, (tf.shape(mask)[0], tf.shape(mask)[1], 1, 1)
    )
    attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask),
                             tf.dtypes.float32) * -10000.0
    encoder_outputs = []
    x = embeddings
    # TFLongformerEncoder
    for layer in self._transformer_layers:
      x = layer([
          x,
          attention_mask,
          is_index_masked,
          is_index_global_attn])
      encoder_outputs.append(x)

    last_encoder_output = encoder_outputs[-1]
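The additive mask built above follows the usual convention: real tokens contribute 0 and padded positions a large negative value, so they vanish after the softmax inside the attention layers. A minimal illustration with a made-up four-token batch (not the encoder's own code):

```python
import tensorflow as tf

mask = tf.constant([[1, 1, 1, 0]], tf.int32)   # 1 = real token, 0 = padding
extended = tf.reshape(mask, (1, 4, 1, 1))
additive = tf.cast(tf.math.abs(1 - extended), tf.float32) * -10000.0
# additive[0, :, 0, 0] == [0., 0., 0., -10000.]: the padded position is
# pushed to a large negative value before the softmax in each attention layer.
```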
@@ -328,19 +315,19 @@ class LongformerEncoder(tf.keras.layers.Layer):
      word_embeddings,
      pad_token_id,
  ):
    """A helper function to pad tokens and mask to work with the Longformer self-attention implementation."""
    # padding
    attention_window = max(self._attention_window)

    assert attention_window % 2 == 0, \
        f'`attention_window` should be an even value. Given {attention_window}'

    input_shape = get_shape_list(
        word_ids) if word_ids is not None else get_shape_list(word_embeddings)
    batch_size, seq_len = input_shape[:2]

    if seq_len is not None:
      padding_len = (attention_window -
                     seq_len % attention_window) % attention_window
    else:
      padding_len = 0
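The modular arithmetic above rounds the sequence length up to the next multiple of the (largest) attention window, and adds nothing when it is already a multiple. A standalone check with illustrative values only:

```python
attention_window = 32
for seq_len in (128, 100, 1):
  padding_len = (attention_window - seq_len % attention_window) % attention_window
  assert (seq_len + padding_len) % attention_window == 0
  print(seq_len, padding_len)  # 128 -> 0, 100 -> 28, 1 -> 31
```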
@@ -355,10 +342,13 @@ class LongformerEncoder(tf.keras.layers.Layer):
        word_embeddings_padding = self._embedding_layer(word_ids_padding)
        return tf.concat([word_embeddings, word_embeddings_padding], axis=-2)

      word_embeddings = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings,
                                lambda: word_embeddings)

      mask = tf.pad(mask, paddings,
                    constant_values=False)  # no attention on the padding tokens
      token_type_ids = tf.pad(type_ids, paddings,
                              constant_values=0)  # pad with token_type_id = 0

    return (
        padding_len,
...
@@ -17,49 +17,13 @@ Longformer attention layer. Modified From huggingface/transformers
"""
import tensorflow as tf

from official.projects.longformer.longformer_attention import \
    LongformerAttention


@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerEncoderBlock(tf.keras.layers.Layer):
"""TransformerEncoderBlock layer. """LongformerEncoderBlock.
This layer implements the Transformer Encoder from
"Attention Is All You Need". (https://arxiv.org/abs/1706.03762),
which combines a `tf.keras.layers.MultiHeadAttention` layer with a
two-layer feedforward network.
References:
[Attention Is All You Need](https://arxiv.org/abs/1706.03762)
[BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding](https://arxiv.org/abs/1810.04805)
"""
def __init__(self,
global_attention_size,
num_attention_heads,
inner_dim,
inner_activation,
# Longformer
attention_window,
layer_id=0,
output_range=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
output_dropout=0.0,
attention_dropout=0.0,
inner_dropout=0.0,
attention_initializer=None,
attention_axes=None,
**kwargs):
"""Initializes `TransformerEncoderBlock`.
Args: Args:
num_attention_heads: Number of attention heads. num_attention_heads: Number of attention heads.
...@@ -94,6 +58,32 @@ class LongformerEncoderBlock(tf.keras.layers.Layer): ...@@ -94,6 +58,32 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
attention over all axes, but batch, heads, and features. attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments/ **kwargs: keyword arguments/
""" """
  def __init__(self,
               global_attention_size,
               num_attention_heads,
               inner_dim,
               inner_activation,
               # Longformer
               attention_window,
               layer_id=0,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               output_dropout=0.0,
               attention_dropout=0.0,
               inner_dropout=0.0,
               attention_initializer=None,
               attention_axes=None,
               **kwargs):
    super().__init__(**kwargs)
    self.global_attention_size = global_attention_size
@@ -133,16 +123,16 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
      input_tensor_shape = tf.TensorShape(input_shape[0])
    else:
      raise ValueError(
          "The type of input shape argument is not supported, got: "
          f"{type(input_shape)}")
    einsum_equation = "abc,cd->abd"
    if len(input_tensor_shape.as_list()) > 3:
      einsum_equation = "...bc,cd->...bd"
    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          f"The input size ({hidden_size}) is not a multiple of the number of "
          f"attention heads ({self._num_heads})")
    self._attention_head_size = int(hidden_size // self._num_heads)
    common_kwargs = dict(
        bias_initializer=self._bias_initializer,
@@ -216,7 +206,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
        epsilon=self._norm_epsilon,
        dtype=tf.float32)

    super().build(input_shape)
  def get_config(self):
    config = {
@@ -258,7 +248,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
            tf.keras.initializers.serialize(self._attention_initializer),
        "attention_axes": self._attention_axes,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))
  def call(self, inputs):
@@ -277,26 +267,23 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
      An output tensor with the same dimensions as input/query tensor.
    """
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 4:
        (
            input_tensor,
            attention_mask,
            is_index_masked,
            is_index_global_attn,
        ) = inputs
        key_value = None
      elif len(inputs) == 5:
        assert False  # No key_value
      else:
        raise ValueError(
            f"Unexpected inputs to {self.__class__} with length {len(inputs)}")
    else:
      input_tensor = inputs
      attention_mask = None
      is_index_masked = None
      is_index_global_attn = None
      key_value = None

    if self._output_range:
@@ -329,7 +316,6 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
          attention_mask=attention_mask,
          is_index_masked=is_index_masked,
          is_index_global_attn=is_index_global_attn,
      )
      # TFLongformerAttention.TFLongformerSelfOutput.* - {.dense}
      attention_output = self._attention_dropout(attention_output)
...
@@ -12,44 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_encoder."""
import numpy as np
import tensorflow as tf

from absl.testing import parameterized
from tensorflow.python.keras import \
    keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.distribute import combinations

from official.projects.longformer.longformer_encoder import LongformerEncoder


@keras_parameterized.run_all_keras_modes
class LongformerEncoderTest(keras_parameterized.TestCase):
  def setUp(self):
    super().setUp()
    np.random.seed(0)
    tf.random.set_seed(0)
  @combinations.generate(combinations.combine(
      attention_window=[32, 128], global_attention_size=[0, 1, 2]))
  def test_encoder(self, attention_window, global_attention_size):
    sequence_length = 128
    batch_size = 2
    vocab_size = 1024
    hidden_size = 256
    network = LongformerEncoder(
        global_attention_size=global_attention_size,
        vocab_size=vocab_size,
        attention_window=[attention_window],
        hidden_size=hidden_size,
        num_layers=1,
        num_attention_heads=4,
        max_sequence_length=512)
    word_id_data = np.random.randint(vocab_size,
                                     size=(batch_size, sequence_length),
                                     dtype=np.int32)
    mask_data = np.random.randint(2, size=(batch_size, sequence_length),
                                  dtype=np.int32)
    type_id_data = np.random.randint(2, size=(batch_size, sequence_length),
                                     dtype=np.int32)
    inputs = {
        'input_word_ids': word_id_data,
        'input_mask': mask_data,
        'input_type_ids': type_id_data,
    }
    outputs = network(inputs)
    self.assertEqual(outputs['sequence_output'].shape,
                     (batch_size, sequence_length, hidden_size))
  @combinations.generate(combinations.combine(
@@ -62,24 +73,28 @@ class LongformerEncoderTest(keras_parameterized.TestCase):
    network = LongformerEncoder(
        global_attention_size=global_attention_size,
        vocab_size=vocab_size,
        attention_window=[32],
        hidden_size=hidden_size,
        num_layers=1,
        num_attention_heads=4,
        max_sequence_length=512,
        norm_first=norm_first)
    word_id_data = np.random.randint(vocab_size,
                                     size=(batch_size, sequence_length),
                                     dtype=np.int32)
    mask_data = np.random.randint(2, size=(batch_size, sequence_length),
                                  dtype=np.int32)
    type_id_data = np.random.randint(2, size=(batch_size, sequence_length),
                                     dtype=np.int32)
    inputs = {
        'input_word_ids': word_id_data,
        'input_mask': mask_data,
        'input_type_ids': type_id_data,
    }
    outputs = network(inputs)
    self.assertEqual(outputs['sequence_output'].shape,
                     (batch_size, sequence_length, hidden_size))
if __name__ == '__main__':
  tf.test.main()
@@ -34,22 +34,24 @@ AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
PolynomialWarmupConfig = optimization.PolynomialWarmupConfig


@dataclasses.dataclass
class LongformerOptimizationConfig(optimization.OptimizationConfig):
  optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
      type='adamw',
      adamw=AdamWeightDecay(
          weight_decay_rate=0.01,
          exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
          epsilon=1e-6))
  learning_rate: optimization.LrConfig = optimization.LrConfig(
      type='polynomial',
      polynomial=PolynomialLr(
          initial_learning_rate=1e-4,
          decay_steps=1000000,
          end_learning_rate=0.0))
  warmup: optimization.WarmupConfig = optimization.WarmupConfig(
      type='polynomial', polynomial=PolynomialWarmupConfig(warmup_steps=10000))
@exp_factory.register_config_factory('longformer/pretraining')
def longformer_pretraining() -> cfg.ExperimentConfig:
@@ -62,11 +64,14 @@ def longformer_pretraining() -> cfg.ExperimentConfig:
              type="any", any=LongformerEncoderConfig()),
          cls_heads=[
              bert.ClsHeadConfig(
                  inner_dim=768, num_classes=2, dropout_rate=0.1,
                  name='next_sentence')
          ]
      ),
      train_data=pretrain_dataloader.BertPretrainDataConfig(
          use_v2_feature_names=True),
      validation_data=pretrain_dataloader.BertPretrainDataConfig(
          use_v2_feature_names=True,
          is_training=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=LongformerOptimizationConfig(), train_steps=1000000),
@@ -76,6 +81,7 @@ def longformer_pretraining() -> cfg.ExperimentConfig:
  ])
  return config
@exp_factory.register_config_factory('longformer/glue')
def longformer_glue() -> cfg.ExperimentConfig:
  config = cfg.ExperimentConfig(
...
@@ -24,7 +24,6 @@ from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance

FLAGS = flags.FLAGS

@@ -43,7 +42,8 @@ def main(_):
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
...