Commit 09e6e71c authored by Zihan Wang's avatar Zihan Wang

lint

parent 32867f40
......@@ -18,27 +18,20 @@ Longformer attention block. Modified From huggingface/transformers
# pylint: disable=g-classes-have-attributes
import collections
import math
import string
import tensorflow as tf
import numpy as np
from keras import constraints
from keras import initializers
from keras import regularizers
from keras.engine.base_layer import Layer
from keras.layers import core
from keras.layers import einsum_dense
from keras.utils import tf_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
from official.modeling.tf_utils import get_shape_list
from typing import Dict, List, Optional, Union
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
......@@ -64,7 +57,7 @@ def _build_attention_equation(rank, attn_axes):
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
source_notation = ''
for i in range(rank):
if i in batch_dims or i == rank - 1:
source_notation += target_notation[i]
......@@ -72,23 +65,21 @@ def _build_attention_equation(rank, attn_axes):
source_notation += _CHR_IDX[letter_offset]
letter_offset += 1
product_notation = "".join([target_notation[i] for i in batch_dims] +
product_notation = ''.join([target_notation[i] for i in batch_dims] +
[target_notation[i] for i in attn_axes] +
[source_notation[i] for i in attn_axes])
dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
product_notation)
dot_product_equation = f'{source_notation},{target_notation}->{product_notation}'
attn_scores_rank = len(product_notation)
combine_equation = "%s,%s->%s" % (product_notation, source_notation,
target_notation)
combine_equation = f'{product_notation},{source_notation}->{target_notation}'
return dot_product_equation, combine_equation, attn_scores_rank
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
input_str = ''
kernel_str = ''
output_str = ''
bias_axes = ''
letter_offset = 0
for i in range(free_dims):
char = _CHR_IDX[i + letter_offset]
......@@ -107,7 +98,7 @@ def _build_proj_equation(free_dims, bound_dims, output_dims):
kernel_str += char
output_str += char
bias_axes += char
equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
equation = f'{input_str},{kernel_str}->{output_str}'
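# For example (assuming this helper mirrors the Keras MultiHeadAttention
# version it is adapted from), _build_proj_equation(free_dims=2,
# bound_dims=1, output_dims=2) yields ('abc,cde->abde', 'de', 4): a
# projection from a (batch, seq, features) input onto two trailing output
# dimensions, with biases on both output axes.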
return equation, bias_axes, len(output_str)
......@@ -115,8 +106,17 @@ def _build_proj_equation(free_dims, bound_dims, output_dims):
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerAttention(tf.keras.layers.MultiHeadAttention):
"""Longformer attention layer.
Args:
attention_window: int representing the window size for attention.
layer_id: int of the id of the layer.
global_attention_size: the number of tokens, counted from the start of the
sequence, that use global attention.
"""
def __init__(self,
attention_window,
layer_id,
......@@ -124,14 +124,16 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
**kwargs):
super().__init__(**kwargs)
self._layer_id = layer_id
_attention_window = attention_window
self._attention_window = attention_window
assert (
_attention_window % 2 == 0
), f"`attention_window` for layer {self._layer_id} has to be an even value. Given {attention_window}"
self._attention_window % 2 == 0
), f"`attention_window` for layer {self._layer_id} has to be an even " \
f"value. Given {self._attention_window}"
assert (
_attention_window > 0
), f"`attention_window` for layer {self._layer_id} has to be positive. Given {attention_window}"
self._one_sided_attn_window_size = _attention_window // 2
self._attention_window > 0
), f"`attention_window` for layer {self._layer_id} has to be positive. " \
f"Given {self._attention_window}"
self._one_sided_attn_window_size = self._attention_window // 2
self.global_attention_size = global_attention_size
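# Example construction (sizes below are illustrative only): an
# `attention_window` of 4 gives a one-sided window of 4 // 2 = 2 tokens on
# each side of every position, e.g.
#   LongformerAttention(num_heads=2, key_dim=4, attention_window=4,
#                       layer_id=0, global_attention_size=0)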
def _build_from_signature(self, query, value, key=None):
......@@ -228,16 +230,15 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
# self._output_dense = self._make_output_dense(
# free_dims, common_kwargs, "attention_output")
self._output_dense = tf.keras.layers.Dense(
units=self._num_heads * self._key_dim, name="dense",
**common_kwargs
)
units=self._num_heads * self._key_dim, name="dense",
**common_kwargs
)
def call(self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
training=None):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
......@@ -256,7 +257,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
attention_scores: Multi-headed attention weights.
"""
if not self._built_from_signature:
self._build_from_signature(query=hidden_states, value=hidden_states, key=hidden_states)
self._build_from_signature(query=hidden_states, value=hidden_states,
key=hidden_states)
# N = `num_attention_heads`
# H = `size_per_head`
......@@ -272,7 +274,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim))) # (B, T, N, key_dim)
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
batch_size, seq_len, num_heads, head_dim = get_shape_list(query)
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
......@@ -293,8 +295,12 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(attn_scores),
[batch_size, seq_len, self._num_heads, self._one_sided_attn_window_size * 2 + 1],
message=f"attn_probs should be of size ({batch_size}, {seq_len}, {num_heads}, {self._one_sided_attn_window_size * 2 + 1}), but is of size {get_shape_list(attn_scores)}",
[batch_size, seq_len, self._num_heads,
self._one_sided_attn_window_size * 2 + 1],
message=f"attn_probs should be of size "
f"({batch_size}, {seq_len}, {num_heads}, "
f"{self._one_sided_attn_window_size * 2 + 1}),"
f" but is of size {get_shape_list(attn_scores)}",
)
# compute global attn indices required throughout the forward fn
......@@ -303,7 +309,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
) = self._get_global_attn_indices(is_index_global_attn, self.global_attention_size)
) = self._get_global_attn_indices(is_index_global_attn,
self.global_attention_size)
# this function is only relevant for global attention
if self.global_attention_size > 0:
attn_scores = self._concat_with_global_key_attn_probs(
......@@ -320,14 +327,18 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
# softmax sometimes inserts NaN if all positions are masked,
# replace them with 0
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
# if is_global_attn==True => [batch_size, seq_len, self.num_heads,
# self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads,
# self.one_sided_attn_window_size * 2 + 1]
if self.global_attention_size > 0:
masked_index = tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self._num_heads, self._one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
(1, 1, self._num_heads,
self._one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
)
else:
masked_index = tf.tile(
......@@ -347,14 +358,17 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
tf.debugging.assert_equal(
get_shape_list(layer_head_mask),
[self._num_heads],
message=f"Head mask for a single layer should be of size {(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
message=f"Head mask for a single layer should be of size "
f"{(self._num_heads)}, but is "
f"{get_shape_list(layer_head_mask)}",
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
# apply dropout
attn_probs = self._dropout_layer(attn_probs, training=training)
value_vectors = tf.reshape(value, (batch_size, seq_len, self._num_heads, self._key_dim)) # TODO: _key_dim == _value_dim
value_vectors = tf.reshape(value, (batch_size, seq_len, self._num_heads,
self._key_dim))
# if global attention, compute sum of global and local attn
if self.global_attention_size > 0:
......@@ -377,33 +391,35 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
message="Unexpected size",
)
attn_output = tf.reshape(attn_output, (batch_size, seq_len, self._num_heads * self._key_dim)) # FIXME
attn_output = tf.reshape(attn_output, (
batch_size, seq_len, self._num_heads * self._key_dim)) # FIXME
# compute value for global attention and overwrite to attention output
# TODO: remove the redundant computation
if self.global_attention_size > 0:
attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
attn_output=attn_output,
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
layer_head_mask=layer_head_mask,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
is_index_masked=is_index_masked,
training=training,
)
attn_output, global_attn_probs = \
self._compute_global_attn_output_from_hidden(
attn_output=attn_output,
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
layer_head_mask=layer_head_mask,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
is_index_masked=is_index_masked,
training=training,
)
else:
global_attn_probs = tf.zeros((batch_size, self._num_heads, max_num_global_attn_indices, seq_len))
global_attn_probs = tf.zeros(
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len))
# make sure that local attention probabilities are set to 0 for indices of global attn
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
# make sure that local attention probabilities are set to 0 for indices of
# global attn
if self.global_attention_size > 0:
masked_global_attn_index = tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self._num_heads, self._one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
(1, 1, self._num_heads,
self._one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
)
else:
masked_global_attn_index = tf.tile(
......@@ -413,28 +429,30 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
attn_probs = tf.where(
masked_global_attn_index,
tf.zeros(get_shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
tf.zeros(get_shape_list(masked_global_attn_index),
dtype=attn_probs.dtype),
attn_probs,
)
# we can return extra information here
attention_output = attn_output # (attn_output, attn_probs, global_attn_probs)
attention_output = attn_output # (attn_output, attn_probs, global_attn_probs)
return attention_output
def get_config(self):
config = {
"layer_id": self._layer_id,
"attention_window": self._one_sided_attn_window_size,
"layer_id": self._layer_id,
"attention_window": self._one_sided_attn_window_size,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
"""
Matrix multiplication of query and key tensors using with a sliding window attention pattern. This
implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an
overlap of size window_overlap
Matrix multiplication of query and key tensors using a sliding window
attention pattern. This implementation splits the input into overlapping
chunks of size 2w (e.g. 512 for pretrained Longformer) with an overlap of
size window_overlap.
"""
batch_size, seq_len, num_heads, head_dim = get_shape_list(query)
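# Shape example (illustrative): with seq_len = 1024 and window_overlap = 256,
# the sequence is split into 1024 // 256 - 1 = 3 chunks of length 512 that
# overlap by 256 tokens, and the returned scores have shape
# (batch_size, 1024, num_heads, 2 * 256 + 1).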
......@@ -442,22 +460,26 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
message=f"Sequence length should be multiple of {window_overlap * 2}. "
f"Given {seq_len}",
)
tf.debugging.assert_equal(
get_shape_list(query),
get_shape_list(key),
message=f"Shape of query and key should be equal, but got query: {get_shape_list(query)} and key: {get_shape_list(key)}",
message=f"Shape of query and key should be equal, but got query: "
f"{get_shape_list(query)} and key: {get_shape_list(key)}",
)
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
# group batch_size and num_heads dimensions into one,
# then chunk seq_len into chunks of size window_overlap * 2
query = tf.reshape(
tf.transpose(query, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim),
)
key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim))
chunked_query = self._chunk(query, window_overlap)
chunked_key = self._chunk(key, window_overlap)
......@@ -466,24 +488,31 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
# bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply
chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query,
chunked_key) # multiply
# convert diagonals into columns
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)
# allocate space for the overall attention matrix where the chunks are combined. The last dimension
# has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
# window_overlap previous words). The following column is attention score from each word to itself, then
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
chunked_attention_scores, paddings)
# allocate space for the overall attention matrix where the chunks are
# combined. The last dimension
# has (window_overlap * 2 + 1) columns. The first (window_overlap) columns
# are the window_overlap lower triangles (attention from a word to
# window_overlap previous words). The following column is attention score
# from each word to itself, then
# followed by window_overlap columns for the upper triangle.
# copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
# - copying the main diagonal and the upper triangle
# copy parts from diagonal_chunked_attention_scores into the combined matrix
# of attentions - copying the main diagonal and the upper triangle
# TODO: This code is most likely not very efficient and should be improved
diagonal_attn_scores_up_triang = tf.concat(
[
diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1],
diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1],
diagonal_chunked_attention_scores[:, :, :window_overlap,
: window_overlap + 1],
diagonal_chunked_attention_scores[:, -1:, window_overlap:,
: window_overlap + 1],
],
axis=1,
)
......@@ -495,7 +524,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
(batch_size * num_heads, 1, window_overlap, window_overlap),
dtype=diagonal_chunked_attention_scores.dtype,
),
diagonal_chunked_attention_scores[:, :, -(window_overlap + 1): -1, window_overlap + 1:],
diagonal_chunked_attention_scores[:, :, -(window_overlap + 1): -1,
window_overlap + 1:],
],
axis=1,
)
......@@ -514,13 +544,13 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
axis=1,
)
first_chunk_mask = (
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
(batch_size * num_heads, 1, window_overlap, window_overlap),
)
< 1
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
(batch_size * num_heads, 1, window_overlap, window_overlap),
)
< 1
)
#first_chunk_mask = tf.repeat(first_chunk_mask, batch_size * num_heads, axis=0)
diagonal_attn_scores_low_triang = tf.where(
first_chunk_mask,
diagonal_attn_scores_first_chunk,
......@@ -541,7 +571,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
(0, 2, 1, 3),
)
diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
diagonal_attention_scores = self._mask_invalid_locations(
diagonal_attention_scores, window_overlap)
return diagonal_attention_scores
......@@ -549,13 +580,15 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
def _mask_invalid_locations(input_tensor, window_overlap):
# create correct upper triangle bool mask
mask_2d_upper = tf.reverse(
tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)),
-1, 0),
axis=[0],
)
# pad to full matrix
padding = tf.convert_to_tensor(
[[0, get_shape_list(input_tensor)[1] - window_overlap], [0, get_shape_list(input_tensor)[3] - window_overlap - 1]]
[[0, get_shape_list(input_tensor)[1] - window_overlap],
[0, get_shape_list(input_tensor)[3] - window_overlap - 1]]
)
# create lower mask
......@@ -565,20 +598,23 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
# broadcast to full matrix
mask_4d = tf.tile(mask_2d[None, :, None, :], (get_shape_list(input_tensor)[0], 1, 1, 1))
mask_4d = tf.tile(mask_2d[None, :, None, :],
(get_shape_list(input_tensor)[0], 1, 1, 1))
# inf tensor used for masking
inf_tensor = -float("inf") * tf.ones_like(input_tensor)
# mask
input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)
input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor,
input_tensor)
return input_tensor
def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap):
def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value,
window_overlap):
"""
Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the
same shape as `attn_probs`
Same as _sliding_chunks_query_key_matmul but for attn_probs and value
tensors. The returned tensor has the same batch, sequence-length and
num-heads layout as `attn_probs`.
"""
batch_size, seq_len, num_heads, head_dim = get_shape_list(value)
......@@ -602,7 +638,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
# group batch_size and num_heads dimensions into one, then chunk seq_len
# into chunks of size 2 window overlap
chunked_attn_probs = tf.reshape(
tf.transpose(attn_probs, (0, 2, 1, 3)),
(
......@@ -619,13 +656,17 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
(batch_size * num_heads, seq_len, head_dim),
)
# pad seq_len with w at the beginning of the sequence and another window overlap at the end
paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
# pad seq_len with w at the beginning of the sequence and another window
# overlap at the end
paddings = tf.convert_to_tensor(
[[0, 0], [window_overlap, window_overlap], [0, 0]])
padded_value = tf.pad(value, paddings, constant_values=-1)
# chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
# chunk padded_value into chunks of size 3 window overlap and an overlap of
# size window overlap
frame_size = 3 * window_overlap * head_dim
frame_hop_size = (get_shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
frame_hop_size = (get_shape_list(padded_value)[
1] * head_dim - frame_size) // chunks_count
chunked_value = tf.signal.frame(
tf.reshape(padded_value, (batch_size * num_heads, -1)),
frame_size,
......@@ -639,7 +680,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(chunked_value),
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap,
head_dim],
message="Chunked value has the wrong shape",
)
......@@ -658,8 +700,10 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
hidden_states_padded = tf.pad(
hidden_states_padded, paddings
) # padding value is not important because it will be overwritten
batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(hidden_states_padded)
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(
hidden_states_padded)
hidden_states_padded = tf.reshape(hidden_states_padded, (
batch_size, chunk_size, hidden_dim, seq_length))
return hidden_states_padded
......@@ -681,21 +725,27 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
"""
total_num_heads, num_chunks, window_overlap, hidden_dim = get_shape_list(chunked_hidden_states)
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
total_num_heads, num_chunks, window_overlap, hidden_dim = get_shape_list(
chunked_hidden_states)
paddings = tf.convert_to_tensor(
[[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
chunked_hidden_states = tf.pad(
chunked_hidden_states, paddings
) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
)
chunked_hidden_states = tf.reshape(
chunked_hidden_states, (total_num_heads, num_chunks, -1)
) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap
)
chunked_hidden_states = chunked_hidden_states[
:, :, :-window_overlap
] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap
]
chunked_hidden_states = tf.reshape(
chunked_hidden_states,
(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap
(
total_num_heads, num_chunks, window_overlap,
window_overlap + hidden_dim),
)
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
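# Shape note: an input of shape
# (total_num_heads, num_chunks, window_overlap, hidden_dim) comes back as
# (total_num_heads, num_chunks, window_overlap,
#  window_overlap + hidden_dim - 1), e.g. (1, 1, 4, 4) -> (1, 1, 4, 7),
# matching the rows of length 7 in the docstring example above.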
return chunked_hidden_states
......@@ -709,16 +759,21 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
# define frame size and frame stride (similar to convolution)
frame_hop_size = window_overlap * hidden_dim
frame_size = 2 * frame_hop_size
hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim))
hidden_states = tf.reshape(hidden_states,
(batch_size, seq_length * hidden_dim))
# chunk with overlap
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size,
frame_hop_size)
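# For example (using the sizes from the unit tests): seq_length = 8,
# hidden_dim = 4 and window_overlap = 2 give frame_size = 16 and
# frame_hop_size = 8, so tf.signal.frame produces 3 overlapping frames that
# the reshape below turns into chunks of shape (batch_size, 3, 4, 4).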
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {get_shape_list(chunked_hidden_states)}.",
message=f"Make sure chunking is correctly applied. Chunked hidden "
f"states should have output dimension"
f" {[batch_size, frame_size, num_output_chunks]}, but got "
f"{get_shape_list(chunked_hidden_states)}.",
)
chunked_hidden_states = tf.reshape(
......@@ -738,19 +793,25 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
max_num_global_attn_indices = global_attention_size
row_indices = tf.range(batch_size)
row_indices = tf.repeat(tf.expand_dims(row_indices, axis=0), repeats=[global_attention_size], axis=0)
row_indices = tf.reshape(row_indices, (batch_size * global_attention_size, 1))
row_indices = tf.repeat(tf.expand_dims(row_indices, axis=0),
repeats=[global_attention_size], axis=0)
row_indices = tf.reshape(row_indices,
(batch_size * global_attention_size, 1))
col_indices = tf.range(global_attention_size)
col_indices = tf.repeat(tf.expand_dims(col_indices, axis=1), repeats=[batch_size], axis=0)
col_indices = tf.repeat(tf.expand_dims(col_indices, axis=1),
repeats=[batch_size], axis=0)
is_index_global_attn_nonzero = tf.concat((row_indices, col_indices), axis=1)
# this is actually same as `is_index_global_attn_nonzero`, since we assume all global attention are the same size
is_local_index_global_attn_nonzero = tf.concat((row_indices, col_indices), axis=1)
# this is actually the same as `is_index_global_attn_nonzero`,
# since we assume every example uses the same number of global tokens
is_local_index_global_attn_nonzero = tf.concat((row_indices, col_indices),
axis=1)
# empty tensor
is_local_index_no_global_attn_nonzero = tf.reshape(tf.expand_dims(tf.range(0), axis=1), (0, 2))
is_local_index_no_global_attn_nonzero = tf.reshape(
tf.expand_dims(tf.range(0), axis=1), (0, 2))
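# With a fixed global prefix, both index tensors above have shape
# (batch_size * global_attention_size, 2) and enumerate every
# (batch index, token index) pair for the first `global_attention_size`
# tokens, while `is_local_index_no_global_attn_nonzero` stays empty.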
return (
max_num_global_attn_indices,
is_index_global_attn_nonzero,
......@@ -759,14 +820,14 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
)
def _concat_with_global_key_attn_probs(
self,
attn_scores,
key_vectors,
query_vectors,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
self,
attn_scores,
key_vectors,
query_vectors,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
):
batch_size = get_shape_list(key_vectors)[0]
......@@ -786,11 +847,14 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
)
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global)
attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors,
key_vectors_only_global)
# (batch_size, max_num_global_attn_indices, seq_len, num_heads)
attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2))
mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key,
(0, 3, 1, 2))
mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[
0],) + tuple(
get_shape_list(attn_probs_from_global_key_trans)[-2:]
)
mask = tf.ones(mask_shape) * -10000.0
......@@ -804,7 +868,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
)
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1))
attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans,
(0, 2, 3, 1))
# concat to attn_probs
# (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
......@@ -812,21 +877,21 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
return attn_scores
def _compute_attn_output_with_global_indices(
self,
value_vectors,
attn_probs,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
self,
value_vectors,
attn_probs,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
):
batch_size = get_shape_list(attn_probs)[0]
# cut local attn probs to global only
attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]
# attn_probs_only_global = tf.slice(attn_probs, [0, 0, 0, 0], get_shape_list(attn_probs)[: -1] + [max_num_global_attn_indices])
# select global value vectors
global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero)
global_value_vectors = tf.gather_nd(value_vectors,
is_index_global_attn_nonzero)
# create only global value vectors
value_vectors_only_global = tf.scatter_nd(
......@@ -841,10 +906,12 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
)
# compute attn output only global
attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global)
attn_output_only_global = tf.einsum("blhs,bshd->blhd",
attn_probs_only_global,
value_vectors_only_global)
# reshape attn probs
attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:]
# attn_probs_without_global = tf.slice(attn_probs, [0, 0, 0, max_num_global_attn_indices], get_shape_list(attn_probs)[: -1] + [get_shape_list(attn_probs)[-1] - max_num_global_attn_indices])
attn_probs_without_global = attn_probs[:, :, :,
max_num_global_attn_indices:]
# compute attn output with global
attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
......@@ -854,29 +921,33 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
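# Both attention outputs have shape (batch_size, seq_len, num_heads,
# head_dim), so their sum keeps the per-head layout expected by the caller.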
return attn_output_only_global + attn_output_without_global
def _compute_global_attn_output_from_hidden(
self,
attn_output,
hidden_states,
max_num_global_attn_indices,
layer_head_mask,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
is_index_masked,
training,
self,
attn_output,
hidden_states,
max_num_global_attn_indices,
layer_head_mask,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
is_index_masked,
training,
):
batch_size, seq_len = get_shape_list(hidden_states)[:2]
# prepare global hidden states
global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero)
global_attn_hidden_states = tf.gather_nd(hidden_states,
is_index_global_attn_nonzero)
global_attn_hidden_states = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_attn_hidden_states,
shape=(batch_size, max_num_global_attn_indices, self._num_heads * self._key_dim),
shape=(
batch_size, max_num_global_attn_indices,
self._num_heads * self._key_dim),
)
# global key, query, value
global_query_vectors_only_global = self._global_query_dense(global_attn_hidden_states)
global_query_vectors_only_global = self._global_query_dense(
global_attn_hidden_states)
global_key_vectors = self._global_key_dense(hidden_states)
global_value_vectors = self._global_value_dense(hidden_states)
......@@ -884,18 +955,24 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
global_query_vectors_only_global /= tf.math.sqrt(
tf.cast(self._key_dim, dtype=global_query_vectors_only_global.dtype)
)
global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
global_query_vectors_only_global = self.reshape_and_transpose(
global_query_vectors_only_global, batch_size)
global_key_vectors = self.reshape_and_transpose(global_key_vectors,
batch_size)
global_value_vectors = self.reshape_and_transpose(global_value_vectors,
batch_size)
# compute attn scores
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
global_attn_scores = tf.matmul(global_query_vectors_only_global,
global_key_vectors, transpose_b=True)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(global_attn_scores),
[batch_size * self._num_heads, max_num_global_attn_indices, seq_len],
message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, but is {get_shape_list(global_attn_scores)}.",
message=f"global_attn_scores have the wrong size. Size should be "
f"{(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, "
f"but is {get_shape_list(global_attn_scores)}.",
)
global_attn_scores = tf.reshape(
......@@ -903,11 +980,13 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len),
)
global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[
0],) + tuple(
get_shape_list(global_attn_scores_trans)[-2:]
)
global_attn_mask = tf.ones(mask_shape) * -10000.0
global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)
global_attn_mask = tf.cast(global_attn_mask,
dtype=global_attn_scores_trans.dtype)
# scatter mask
global_attn_scores_trans = tf.tensor_scatter_nd_update(
......@@ -916,9 +995,10 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
global_attn_mask,
)
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
# mask global attn scores
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, get_shape_list(global_attn_scores)[1], 1, 1))
attn_mask = tf.tile(is_index_masked[:, None, None, :],
(1, get_shape_list(global_attn_scores)[1], 1, 1))
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
global_attn_scores = tf.reshape(
global_attn_scores,
......@@ -934,17 +1014,22 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
tf.debugging.assert_equal(
get_shape_list(layer_head_mask),
[self._num_heads],
message=f"Head mask for a single layer should be of size {(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
message=f"Head mask for a single layer should be of size "
f"{(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self._num_heads, max_num_global_attn_indices, seq_len)
global_attn_probs_float = tf.reshape(layer_head_mask,
(1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float,
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len)
)
global_attn_probs_float = tf.reshape(
global_attn_probs_float, (batch_size * self._num_heads, max_num_global_attn_indices, seq_len)
global_attn_probs_float,
(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)
)
# dropout
global_attn_probs = self._global_dropout_layer(global_attn_probs_float, training=training)
global_attn_probs = self._global_dropout_layer(global_attn_probs_float,
training=training)
# global attn output
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
......@@ -952,8 +1037,11 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(global_attn_output),
[batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim],
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, but is {get_shape_list(global_attn_output)}.",
[batch_size * self._num_heads, max_num_global_attn_indices,
self._key_dim],
message=f"global_attn_output tensor has the wrong size. Size should be "
f"{(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, "
f"but is {get_shape_list(global_attn_output)}.",
)
global_attn_output = tf.reshape(
......@@ -977,7 +1065,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
)
global_attn_probs = tf.reshape(
global_attn_probs, (batch_size, self._num_heads, max_num_global_attn_indices, seq_len)
global_attn_probs,
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len)
)
attn_output = self._output_dense(attn_output)
......
......@@ -12,25 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the attention layer."""
"""Tests for official.projects.longformer.longformer_attention."""
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras import \
keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.projects.longformer import longformer_attention
from official.modeling.tf_utils import get_shape_list
def _create_mock_attention_data(
num_heads,
key_dim,
value_dim,
q_seq_length,
kv_seq_length,
batch_size,
include_mask=False):
num_heads,
key_dim,
value_dim,
q_seq_length,
kv_seq_length,
batch_size,
include_mask=False):
"""Creates mock testing data.
Args:
......@@ -48,15 +49,15 @@ def _create_mock_attention_data(
value_shape = (batch_size, kv_seq_length, value_dim)
data = dict(
query=tf.random.normal(shape=query_shape),
value=tf.random.normal(shape=value_shape),
key=tf.random.normal(shape=value_shape))
query=tf.random.normal(shape=query_shape),
value=tf.random.normal(shape=value_shape),
key=tf.random.normal(shape=value_shape))
total_seq_length = kv_seq_length
if include_mask:
mask_shape = (batch_size, num_heads, q_seq_length, total_seq_length)
mask_data = np.random.randint(2, size=mask_shape).astype("float32")
mask_data = np.random.randint(2, size=mask_shape).astype('float32')
mask_data = dict(attention_mask=mask_data)
data.update(mask_data)
......@@ -65,6 +66,12 @@ def _create_mock_attention_data(
@keras_parameterized.run_all_keras_modes
class LongformerAttentionTest(keras_parameterized.TestCase):
def setUp(self):
super(LongformerAttentionTest, self).setUp()
np.random.seed(0)
tf.random.set_seed(0)
def _get_hidden_states(self):
return tf.convert_to_tensor(
[
......@@ -116,25 +123,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
def test_diagonalize(self):
hidden_states = self._get_hidden_states()
hidden_states = tf.reshape(hidden_states, (1, 8, 4)) # set seq length = 8, hidden dim = 4
chunked_hidden_states = longformer_attention.LongformerAttention._chunk(hidden_states, window_overlap=2)
hidden_states = tf.reshape(hidden_states,
(1, 8, 4)) # set seq length = 8, hidden dim = 4
chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
window_overlap_size = get_shape_list(chunked_hidden_states)[2]
self.assertTrue(window_overlap_size == 4)
padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(chunked_hidden_states)
padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(
chunked_hidden_states)
self.assertTrue(
get_shape_list(padded_hidden_states)[-1] == get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
get_shape_list(padded_hidden_states)[-1] ==
get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
)
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
tf.debugging.assert_near(padded_hidden_states[0, 0, 0, :4], chunked_hidden_states[0, 0, 0], rtol=1e-3)
tf.debugging.assert_near(padded_hidden_states[0, 0, 0, 4:], tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3)
tf.debugging.assert_near(padded_hidden_states[0, 0, 0, :4],
chunked_hidden_states[0, 0, 0], rtol=1e-3)
tf.debugging.assert_near(padded_hidden_states[0, 0, 0, 4:],
tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3)
# last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
tf.debugging.assert_near(padded_hidden_states[0, 0, -1, 3:], chunked_hidden_states[0, 0, -1], rtol=1e-3)
tf.debugging.assert_near(padded_hidden_states[0, 0, -1, 3:],
chunked_hidden_states[0, 0, -1], rtol=1e-3)
tf.debugging.assert_near(
padded_hidden_states[0, 0, -1, :3], tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3
padded_hidden_states[0, 0, -1, :3],
tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3
)
def test_pad_and_transpose_last_two_dims(self):
......@@ -142,16 +157,21 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
self.assertTrue(get_shape_list(hidden_states), [1, 8, 4])
# pad along seq length dim
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]],
dtype=tf.dtypes.int32)
hidden_states = longformer_attention.LongformerAttention._chunk(hidden_states, window_overlap=2)
padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(
hidden_states, paddings)
self.assertTrue(get_shape_list(padded_hidden_states) == [1, 1, 8, 5])
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
tf.debugging.assert_near(expected_added_dim,
padded_hidden_states[0, 0, -1, :], rtol=1e-6)
tf.debugging.assert_near(
hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
hidden_states[0, 0, -1, :],
tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
)
def test_mask_invalid_locations(self):
......@@ -159,39 +179,55 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = tf.reshape(hidden_states, (batch_size, seq_length, hidden_size))
hidden_states = longformer_attention.LongformerAttention._chunk(hidden_states, window_overlap=2)
hid_states_1 = longformer_attention.LongformerAttention._mask_invalid_locations(hidden_states, 1)
hid_states_2 = longformer_attention.LongformerAttention._mask_invalid_locations(hidden_states, 2)
hid_states_3 = longformer_attention.LongformerAttention._mask_invalid_locations(hidden_states[:, :, :, :3], 2)
hid_states_4 = longformer_attention.LongformerAttention._mask_invalid_locations(hidden_states[:, :, 2:, :], 2)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)) == 8)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)) == 24)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)) == 24)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)) == 12)
hidden_states = tf.reshape(hidden_states,
(batch_size, seq_length, hidden_size))
hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
hid_states_1 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states, 1)
hid_states_2 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states, 2)
hid_states_3 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states[:, :, :, :3], 2)
hid_states_4 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states[:, :, 2:, :], 2)
self.assertTrue(tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)) == 8)
self.assertTrue(tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)) == 24)
self.assertTrue(tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)) == 24)
self.assertTrue(tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)) == 12)
def test_chunk(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = tf.reshape(hidden_states, (batch_size, seq_length, hidden_size))
hidden_states = tf.reshape(hidden_states,
(batch_size, seq_length, hidden_size))
chunked_hidden_states = longformer_attention.LongformerAttention._chunk(hidden_states, window_overlap=2)
chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
# expected slices across chunk and seq length dim
expected_slice_along_seq_length = tf.convert_to_tensor([0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
expected_slice_along_chunk = tf.convert_to_tensor([0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)
expected_slice_along_seq_length = tf.convert_to_tensor(
[0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
expected_slice_along_chunk = tf.convert_to_tensor(
[0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)
self.assertTrue(get_shape_list(chunked_hidden_states) == [1, 3, 4, 4])
tf.debugging.assert_near(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, rtol=1e-3)
tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, rtol=1e-3)
tf.debugging.assert_near(chunked_hidden_states[0, :, 0, 0],
expected_slice_along_seq_length, rtol=1e-3)
tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0],
expected_slice_along_chunk, rtol=1e-3)
def test_layer_local_attn(self):
hidden_states = self._get_hidden_states()
batch_size, seq_length, hidden_size = hidden_states.shape
batch_size, seq_length, _ = hidden_states.shape
layer = longformer_attention.LongformerAttention(
num_heads=2,
key_dim=4,
......@@ -203,14 +239,15 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
is_index_global_attn = tf.math.greater(attention_mask, 1)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0,
attention_mask[:, :, None, None])
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
output_hidden_states = layer(
hidden_states=hidden_states, attention_mask=attention_mask,
is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
)[0]
self.assertTrue(output_hidden_states.shape, (1, 4, 8))
......@@ -226,32 +263,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
)
hidden_states = self._get_hidden_states()
hidden_states = tf.concat([self._get_hidden_states(), self._get_hidden_states() - 0.5], axis=0)
hidden_states = tf.concat(
[self._get_hidden_states(), self._get_hidden_states() - 0.5], axis=0)
batch_size, seq_length, hidden_size = hidden_states.shape
# create attn mask
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] == 0, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] == 0, 10000.0, attention_mask_2)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] == 0, 10000.0,
attention_mask_1)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0,
attention_mask_1)
attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] == 0, 10000.0,
attention_mask_2)
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
output_hidden_states = layer(
hidden_states=hidden_states, attention_mask=-tf.math.abs(attention_mask),
is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn,
)[0]
hidden_states=hidden_states, attention_mask=-tf.math.abs(attention_mask),
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
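  def test_build_attention_equation(self):
    # Sanity check (illustrative): for the common rank-4
    # (batch, seq, heads, head_dim) layout with attention over the sequence
    # axis, the module-level helper is expected to emit the standard Keras
    # MultiHeadAttention einsum equations.
    dot_eq, comb_eq, scores_rank = (
        longformer_attention._build_attention_equation(rank=4, attn_axes=(1,)))
    self.assertEqual(dot_eq, 'aecd,abcd->acbe')
    self.assertEqual(comb_eq, 'acbe,aecd->abcd')
    self.assertEqual(scores_rank, 4)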
if __name__ == "__main__":
np.random.seed(0)
tf.random.set_seed(0)
if __name__ == '__main__':
tf.test.main()
......@@ -23,29 +23,16 @@ from absl import logging
import tensorflow as tf
from official.nlp.modeling import layers
from official.projects.longformer.longformer_encoder_block import LongformerEncoderBlock
from official.projects.longformer.longformer_encoder_block import \
LongformerEncoderBlock
from official.modeling.tf_utils import get_shape_list
_Initializer = Union[str, tf.keras.initializers.Initializer]
_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
# Transferred from huggingface.longformer.TFLongformerMainLayer & TFLongformerEncoder
class LongformerEncoder(tf.keras.layers.Layer):
"""Bi-directional Transformer-based encoder network.
This network implements a bi-directional Transformer-based encoder as
described in "BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the
embedding lookups and transformer layers, but not the masked language model
or classification task networks.
The default values for this object are taken from the BERT-Base implementation
in "BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding".
"""Bi-directional Longformer-based encoder network.
Args:
vocab_size: The size of the token vocabulary.
attention_window: list of ints representing the window size for each layer.
......@@ -85,27 +72,27 @@ class LongformerEncoder(tf.keras.layers.Layer):
"""
def __init__(
self,
vocab_size: int,
attention_window: Union[List[int], int] = 512,
global_attention_size: int = 0,
pad_token_id: int = 1,
hidden_size: int = 768,
num_layers: int = 12,
num_attention_heads: int = 12,
max_sequence_length: int = 512,
type_vocab_size: int = 16,
inner_dim: int = 3072,
inner_activation: Callable[..., Any] = _approx_gelu,
output_dropout: float = 0.1,
attention_dropout: float = 0.1,
initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
stddev=0.02),
output_range: Optional[int] = None,
embedding_width: Optional[int] = None,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
**kwargs):
self,
vocab_size: int,
attention_window: Union[List[int], int] = 512,
global_attention_size: int = 0,
pad_token_id: int = 1,
hidden_size: int = 768,
num_layers: int = 12,
num_attention_heads: int = 12,
max_sequence_length: int = 512,
type_vocab_size: int = 16,
inner_dim: int = 3072,
inner_activation: Callable[..., Any] = _approx_gelu,
output_dropout: float = 0.1,
attention_dropout: float = 0.1,
initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
stddev=0.02),
output_range: Optional[int] = None,
embedding_width: Optional[int] = None,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
**kwargs):
super().__init__(**kwargs)
# Longformer args
self._attention_window = attention_window
......@@ -120,93 +107,91 @@ class LongformerEncoder(tf.keras.layers.Layer):
if embedding_layer is None:
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
self._embedding_dropout = tf.keras.layers.Dropout(
rate=output_dropout, name='embedding_dropout')
rate=output_dropout, name='embedding_dropout')
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
self._transformer_layers = []
self._attention_mask_layer = layers.SelfAttentionMask(
name='self_attention_mask')
name='self_attention_mask')
for i in range(num_layers):
layer = LongformerEncoderBlock(
global_attention_size=global_attention_size,
num_attention_heads=num_attention_heads,
inner_dim=inner_dim,
inner_activation=inner_activation,
# Longformer, instead of passing a list of attention_window, pass a value to sub-block
attention_window=attention_window if isinstance(attention_window, int) else attention_window[i],
layer_id=i,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
name='transformer/layer_%d' % i)
global_attention_size=global_attention_size,
num_attention_heads=num_attention_heads,
inner_dim=inner_dim,
inner_activation=inner_activation,
attention_window=attention_window[i],
layer_id=i,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
name=f'transformer/layer_{i}')
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
self._config = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
# Longformer
'attention_window': attention_window,
'global_attention_size': global_attention_size,
'pad_token_id': pad_token_id,
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'attention_window': attention_window,
'global_attention_size': global_attention_size,
'pad_token_id': pad_token_id,
}
self.inputs = dict(
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
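# When called, the encoder expects a dict with 'input_word_ids', 'input_mask'
# and 'input_type_ids' (optionally 'input_word_embeddings') and returns a
# dict with 'sequence_output', 'pooled_output' and 'encoder_outputs'; see
# `call` below.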
def call(self, inputs):
word_embeddings = None
......@@ -214,22 +199,23 @@ class LongformerEncoder(tf.keras.layers.Layer):
word_ids = inputs.get('input_word_ids') # input_ids
mask = inputs.get('input_mask') # attention_mask
type_ids = inputs.get('input_type_ids') # token_type_ids
word_embeddings = inputs.get('input_word_embeddings', None) # input_embeds
word_embeddings = inputs.get('input_word_embeddings',
None) # input_embeds
else:
raise ValueError('Unexpected inputs type to %s.' % self.__class__)
raise ValueError(f'Unexpected inputs type to {self.__class__}.')
(
padding_len,
word_ids,
mask,
type_ids,
word_embeddings,
padding_len,
word_ids,
mask,
type_ids,
word_embeddings,
) = self._pad_to_window_size(
word_ids=word_ids,
mask=mask,
type_ids=type_ids,
word_embeddings=word_embeddings,
pad_token_id=self._pad_token_id
word_ids=word_ids,
mask=mask,
type_ids=type_ids,
word_embeddings=word_embeddings,
pad_token_id=self._pad_token_id
)
if word_embeddings is None:
......@@ -247,46 +233,47 @@ class LongformerEncoder(tf.keras.layers.Layer):
batch_size, seq_len = get_shape_list(mask)
# create masks with fixed len global_attention_size
mask = tf.transpose(tf.concat(values=[tf.ones((self._global_attention_size, batch_size), tf.int32) * 2,
tf.transpose(mask)[self._global_attention_size:]], axis=0))
mask = tf.transpose(tf.concat(
values=[tf.ones((self._global_attention_size, batch_size), tf.int32) * 2,
tf.transpose(mask)[self._global_attention_size:]], axis=0))
is_index_masked = tf.math.less(mask, 1)
is_index_global_attn = tf.transpose(tf.concat(values=[
tf.ones((self._global_attention_size, batch_size), tf.bool), tf.zeros((seq_len - self._global_attention_size,
batch_size), tf.bool)
tf.ones((self._global_attention_size, batch_size), tf.bool),
tf.zeros((seq_len - self._global_attention_size,
batch_size), tf.bool)
], axis=0))
is_global_attn = self._global_attention_size > 0
# Longformer
attention_mask = mask
extended_attention_mask = tf.reshape(
attention_mask, (tf.shape(mask)[0], tf.shape(mask)[1], 1, 1)
attention_mask, (tf.shape(mask)[0], tf.shape(mask)[1], 1, 1)
)
attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask),
tf.dtypes.float32) * -10000.0
encoder_outputs = []
x = embeddings
# TFLongformerEncoder
for i, layer in enumerate(self._transformer_layers):
for layer in self._transformer_layers:
x = layer([
x,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn])
x,
attention_mask,
is_index_masked,
is_index_global_attn])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
if padding_len > 0:
last_encoder_output = last_encoder_output[:, :-padding_len]
last_encoder_output = last_encoder_output[:, :-padding_len]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
return dict(
sequence_output=last_encoder_output,
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
sequence_output=last_encoder_output,
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
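The additive mask built in call() above follows the usual BERT-style convention: attended positions contribute 0.0 and padded positions a large negative bias, while global tokens are tracked separately through is_index_global_attn rather than through the bias itself. A minimal standalone sketch of the same arithmetic on a toy mask (illustrative only, not part of this change):

import tensorflow as tf

toy_mask = tf.constant([[1, 1, 1, 0]], dtype=tf.int32)  # 1 = attend, 0 = padding
extended = tf.reshape(toy_mask, (1, 4, 1, 1))           # [batch, seq_len, 1, 1]
additive = tf.cast(tf.abs(1 - extended), tf.float32) * -10000.0
print(tf.squeeze(additive).numpy())                     # [0. 0. 0. -10000.]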
def get_embedding_table(self):
return self._embedding_layer.embeddings
......@@ -311,36 +298,36 @@ class LongformerEncoder(tf.keras.layers.Layer):
def from_config(cls, config, custom_objects=None):
if 'embedding_layer' in config and config['embedding_layer'] is not None:
warn_string = (
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you continue to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you continue to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
print('WARNING: ' + warn_string)
logging.warn(warn_string)
return cls(**config)
def _pad_to_window_size(
self,
word_ids,
mask,
type_ids,
word_embeddings,
pad_token_id,
self,
word_ids,
mask,
type_ids,
word_embeddings,
pad_token_id,
):
"""A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
# padding
attention_window = (
self._attention_window if isinstance(self._attention_window, int) else max(self._attention_window)
)
attention_window = max(self._attention_window)
assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
assert attention_window % 2 == 0, \
f'`attention_window` should be an even value. Given {attention_window}'
input_shape = get_shape_list(word_ids) if word_ids is not None else get_shape_list(word_embeddings)
input_shape = get_shape_list(
word_ids) if word_ids is not None else get_shape_list(word_embeddings)
batch_size, seq_len = input_shape[:2]
if seq_len is not None:
padding_len = (attention_window - seq_len % attention_window) % attention_window
padding_len = (attention_window -
seq_len % attention_window) % attention_window
else:
padding_len = 0
......@@ -355,14 +342,17 @@ class LongformerEncoder(tf.keras.layers.Layer):
word_embeddings_padding = self._embedding_layer(word_ids_padding)
return tf.concat([word_embeddings, word_embeddings_padding], axis=-2)
word_embeddings = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: word_embeddings)
word_embeddings = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings,
lambda: word_embeddings)
mask = tf.pad(mask, paddings, constant_values=False) # no attention on the padding tokens
token_type_ids = tf.pad(type_ids, paddings, constant_values=0) # pad with token_type_id = 0
mask = tf.pad(mask, paddings,
constant_values=False) # no attention on the padding tokens
token_type_ids = tf.pad(type_ids, paddings,
constant_values=0) # pad with token_type_id = 0
return (
padding_len,
word_ids,
mask,
token_type_ids,
word_embeddings,)
padding_len,
word_ids,
mask,
token_type_ids,
word_embeddings,)
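_pad_to_window_size rounds the sequence length up to the next multiple of the largest attention window so the sliding-window attention can chunk the sequence evenly. A standalone check of that arithmetic, assuming nothing beyond the formula used above:

# Mirrors the padding_len computation in _pad_to_window_size.
def pad_len(seq_len: int, attention_window: int) -> int:
  return (attention_window - seq_len % attention_window) % attention_window

assert pad_len(100, 32) == 28  # 100 tokens are padded up to 128
assert pad_len(128, 32) == 0   # already a multiple of the window, no padding
assert pad_len(512, 128) == 0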
......@@ -17,49 +17,13 @@ Longformer attention layer. Modified From huggingface/transformers
"""
import tensorflow as tf
from official.projects.longformer.longformer_attention import LongformerAttention
from official.projects.longformer.longformer_attention import \
LongformerAttention
@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerEncoderBlock(tf.keras.layers.Layer):
"""TransformerEncoderBlock layer.
This layer implements the Transformer Encoder from
"Attention Is All You Need". (https://arxiv.org/abs/1706.03762),
which combines a `tf.keras.layers.MultiHeadAttention` layer with a
two-layer feedforward network.
References:
[Attention Is All You Need](https://arxiv.org/abs/1706.03762)
[BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding](https://arxiv.org/abs/1810.04805)
"""
def __init__(self,
global_attention_size,
num_attention_heads,
inner_dim,
inner_activation,
# Longformer
attention_window,
layer_id=0,
output_range=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
output_dropout=0.0,
attention_dropout=0.0,
inner_dropout=0.0,
attention_initializer=None,
attention_axes=None,
**kwargs):
"""Initializes `TransformerEncoderBlock`.
"""LongformerEncoderBlock.
Args:
num_attention_heads: Number of attention heads.
......@@ -94,6 +58,32 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments.
"""
def __init__(self,
global_attention_size,
num_attention_heads,
inner_dim,
inner_activation,
# Longformer
attention_window,
layer_id=0,
output_range=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
output_dropout=0.0,
attention_dropout=0.0,
inner_dropout=0.0,
attention_initializer=None,
attention_axes=None,
**kwargs):
super().__init__(**kwargs)
self.global_attention_size = global_attention_size
......@@ -121,7 +111,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
self._inner_dropout = inner_dropout
if attention_initializer:
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_axes = attention_axes
......@@ -133,58 +123,58 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
input_tensor_shape = tf.TensorShape(input_shape[0])
else:
raise ValueError(
"The type of input shape argument is not supported, got: %s" %
type(input_shape))
f"The type of input shape argument is not supported, got: "
f"{type(input_shape)}")
einsum_equation = "abc,cd->abd"
if len(input_tensor_shape.as_list()) > 3:
einsum_equation = "...bc,cd->...bd"
hidden_size = input_tensor_shape[-1]
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
f"The input size ({hidden_size}) is not a multiple of the number of attention "
f"heads ({self._num_heads})")
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
# TFLongformerSelfAttention + TFLongformerSelfOutput.dense
self._attention_layer = LongformerAttention(
# Longformer
layer_id=self._layer_id,
global_attention_size=self.global_attention_size,
attention_window=self._attention_window,
num_heads=self._num_heads,
key_dim=self._attention_head_size,
dropout=self._attention_dropout,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
attention_axes=self._attention_axes,
name="self_attention",
**common_kwargs)
# Longformer
layer_id=self._layer_id,
global_attention_size=self.global_attention_size,
attention_window=self._attention_window,
num_heads=self._num_heads,
key_dim=self._attention_head_size,
dropout=self._attention_dropout,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
attention_axes=self._attention_axes,
name="self_attention",
**common_kwargs)
# TFLongformerSelfOutput.dropout
self._attention_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
# TFLongformerSelfOutput.Layernorm
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
# TFLongformerIntermediate
# TFLongformerIntermediate.dense
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
name="intermediate",
**common_kwargs)
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
......@@ -193,72 +183,72 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
policy = tf.float32
# TFLongformerIntermediate.intermediate_act_fn
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._inner_activation, dtype=policy)
self._inner_activation, dtype=policy)
# Dropout between the two feedforward layers (no direct TFLongformer counterpart).
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
rate=self._inner_dropout)
# TFLongformerOutput
# TFLongformerOutput.dense
self._output_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer,
**common_kwargs)
einsum_equation,
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer,
**common_kwargs)
# TFLongformerOutput.dropout
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
# TFLongformerOutput.layernorm
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
super(LongformerEncoderBlock, self).build(input_shape)
super().build(input_shape)
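The "abc,cd->abd" equation selected in build() (or "...bc,cd->...bd" for inputs of rank greater than three) is simply a position-wise dense projection over the last axis, which is what the intermediate and output EinsumDense layers rely on. A quick standalone check with toy shapes (sizes are illustrative only):

import tensorflow as tf

x = tf.random.uniform((2, 4, 8))    # [batch, seq_len, hidden]
w = tf.random.uniform((8, 16))      # [hidden, inner_dim]
y = tf.einsum('abc,cd->abd', x, w)  # dense projection applied independently per position
assert y.shape == (2, 4, 16)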
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"inner_dim":
self._inner_dim,
"inner_activation":
self._inner_activation,
"output_dropout":
self._output_dropout_rate,
"attention_dropout":
self._attention_dropout_rate,
"output_range":
self._output_range,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon,
"inner_dropout":
self._inner_dropout,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer),
"attention_axes": self._attention_axes,
"num_attention_heads":
self._num_heads,
"inner_dim":
self._inner_dim,
"inner_activation":
self._inner_activation,
"output_dropout":
self._output_dropout_rate,
"attention_dropout":
self._attention_dropout_rate,
"output_range":
self._output_range,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon,
"inner_dropout":
self._inner_dropout,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer),
"attention_axes": self._attention_axes,
}
base_config = super(LongformerEncoderBlock, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
......@@ -277,26 +267,23 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
An output tensor with the same dimensions as the input/query tensor.
"""
if isinstance(inputs, (list, tuple)):
if len(inputs) == 5:
if len(inputs) == 4:
(
input_tensor,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn
) = inputs
key_value = None
elif len(inputs) == 6:
elif len(inputs) == 5:
assert False # No key_value
else:
raise ValueError("Unexpected inputs to %s with length at %d" %
(self.__class__, len(inputs)))
raise ValueError(f"Unexpected inputs to {self.__class__} with length at {len(inputs)}")
else:
input_tensor = inputs
attention_mask = None
is_index_masked = None
is_index_global_attn = None
is_global_attn = None
key_value = None
if self._output_range:
......@@ -325,11 +312,10 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
# attention_output = self._attention_layer(
# query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_layer(
hidden_states=target_tensor,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn
hidden_states=target_tensor,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
)
# TFLongformerAttention.TFLongformerSelfOutput.* - {.dense}
attention_output = self._attention_dropout(attention_output)
......
......@@ -12,44 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.bigbird.encoder."""
"""Tests for official.nlp.projects.longformer.longformer_encoder."""
import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras import \
keras_parameterized # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.distribute import combinations
from official.projects.longformer.longformer_encoder import LongformerEncoder
@keras_parameterized.run_all_keras_modes
class LongformerEncoderTest(keras_parameterized.TestCase):
def setUp(self):
super(LongformerEncoderTest, self).setUp()
np.random.seed(0)
tf.random.set_seed(0)
@combinations.generate(combinations.combine(
attention_window=[32, 128], global_attention_size=[0, 1, 2]))
attention_window=[32, 128], global_attention_size=[0, 1, 2]))
def test_encoder(self, attention_window, global_attention_size):
sequence_length = 128
batch_size = 2
vocab_size = 1024
hidden_size=256
hidden_size = 256
network = LongformerEncoder(
global_attention_size=global_attention_size,
vocab_size=vocab_size,
attention_window=attention_window,
hidden_size=hidden_size,
num_layers=1,
num_attention_heads=4,
max_sequence_length=512)
word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length), dtype=np.int32)
mask_data = np.random.randint(2, size=(batch_size, sequence_length), dtype=np.int32)
type_id_data = np.random.randint(2, size=(batch_size, sequence_length), dtype=np.int32)
global_attention_size=global_attention_size,
vocab_size=vocab_size,
attention_window=[attention_window],
hidden_size=hidden_size,
num_layers=1,
num_attention_heads=4,
max_sequence_length=512)
word_id_data = np.random.randint(vocab_size,
size=(batch_size, sequence_length),
dtype=np.int32)
mask_data = np.random.randint(2, size=(batch_size, sequence_length),
dtype=np.int32)
type_id_data = np.random.randint(2, size=(batch_size, sequence_length),
dtype=np.int32)
inputs = {
'input_word_ids': word_id_data,
'input_mask': mask_data,
'input_type_ids': type_id_data,
'input_word_ids': word_id_data,
'input_mask': mask_data,
'input_type_ids': type_id_data,
}
outputs = network(inputs)
self.assertEqual(outputs["sequence_output"].shape,
self.assertEqual(outputs['sequence_output'].shape,
(batch_size, sequence_length, hidden_size))
@combinations.generate(combinations.combine(
......@@ -60,26 +71,30 @@ class LongformerEncoderTest(keras_parameterized.TestCase):
vocab_size = 1024
hidden_size = 256
network = LongformerEncoder(
global_attention_size=global_attention_size,
vocab_size=vocab_size,
attention_window=32,
hidden_size=hidden_size,
num_layers=1,
num_attention_heads=4,
max_sequence_length=512,
norm_first=norm_first)
word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length), dtype=np.int32)
mask_data = np.random.randint(2, size=(batch_size, sequence_length), dtype=np.int32)
type_id_data = np.random.randint(2, size=(batch_size, sequence_length), dtype=np.int32)
global_attention_size=global_attention_size,
vocab_size=vocab_size,
attention_window=[32],
hidden_size=hidden_size,
num_layers=1,
num_attention_heads=4,
max_sequence_length=512,
norm_first=norm_first)
word_id_data = np.random.randint(vocab_size,
size=(batch_size, sequence_length),
dtype=np.int32)
mask_data = np.random.randint(2, size=(batch_size, sequence_length),
dtype=np.int32)
type_id_data = np.random.randint(2, size=(batch_size, sequence_length),
dtype=np.int32)
inputs = {
'input_word_ids': word_id_data,
'input_mask': mask_data,
'input_type_ids': type_id_data,
'input_word_ids': word_id_data,
'input_mask': mask_data,
'input_type_ids': type_id_data,
}
outputs = network(inputs)
self.assertEqual(outputs["sequence_output"].shape,
self.assertEqual(outputs['sequence_output'].shape,
(batch_size, sequence_length, hidden_size))
if __name__ == "__main__":
tf.test.main()
\ No newline at end of file
if __name__ == '__main__':
tf.test.main()
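After this change attention_window is treated as a per-layer list: the encoder indexes attention_window[i] for layer i, and the tests above now pass a one-element list. A hedged sketch of the same pattern with two layers; the concrete sizes are illustrative, not taken from the tests:

import numpy as np
from official.projects.longformer.longformer_encoder import LongformerEncoder

encoder = LongformerEncoder(
    global_attention_size=1,
    vocab_size=1024,
    attention_window=[32, 32],  # one window size per layer (num_layers == 2)
    hidden_size=256,
    num_layers=2,
    num_attention_heads=4,
    max_sequence_length=512)
inputs = {
    'input_word_ids': np.random.randint(1024, size=(2, 128), dtype=np.int32),
    'input_mask': np.ones((2, 128), dtype=np.int32),
    'input_type_ids': np.zeros((2, 128), dtype=np.int32),
}
outputs = encoder(inputs)
assert outputs['sequence_output'].shape == (2, 128, 256)
assert outputs['pooled_output'].shape == (2, 256)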
......@@ -34,84 +34,90 @@ AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
PolynomialWarmupConfig = optimization.PolynomialWarmupConfig
@dataclasses.dataclass
class LongformerOptimizationConfig(optimization.OptimizationConfig):
optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
type="adamw",
adamw=AdamWeightDecay(
weight_decay_rate=0.01,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
epsilon=1e-6))
type='adamw',
adamw=AdamWeightDecay(
weight_decay_rate=0.01,
exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
epsilon=1e-6))
learning_rate: optimization.LrConfig = optimization.LrConfig(
type="polynomial",
polynomial=PolynomialLr(
initial_learning_rate=1e-4,
decay_steps=1000000,
end_learning_rate=0.0))
type='polynomial',
polynomial=PolynomialLr(
initial_learning_rate=1e-4,
decay_steps=1000000,
end_learning_rate=0.0))
warmup: optimization.WarmupConfig = optimization.WarmupConfig(
type="polynomial", polynomial=PolynomialWarmupConfig(warmup_steps=10000))
type='polynomial', polynomial=PolynomialWarmupConfig(warmup_steps=10000))
@exp_factory.register_config_factory('longformer/pretraining')
def longformer_pretraining() -> cfg.ExperimentConfig:
"""BERT pretraining experiment."""
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(enable_xla=True),
task=masked_lm.MaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
type="any", any=LongformerEncoderConfig()),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
]
),
train_data=pretrain_dataloader.BertPretrainDataConfig(use_v2_feature_names=True),
validation_data=pretrain_dataloader.BertPretrainDataConfig(use_v2_feature_names=True,
is_training=False)),
trainer=cfg.TrainerConfig(
optimizer_config=LongformerOptimizationConfig(), train_steps=1000000),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
runtime=cfg.RuntimeConfig(enable_xla=True),
task=masked_lm.MaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
type="any", any=LongformerEncoderConfig()),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=768, num_classes=2, dropout_rate=0.1,
name='next_sentence')
]
),
train_data=pretrain_dataloader.BertPretrainDataConfig(
use_v2_feature_names=True),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
use_v2_feature_names=True,
is_training=False)),
trainer=cfg.TrainerConfig(
optimizer_config=LongformerOptimizationConfig(), train_steps=1000000),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('longformer/glue')
def longformer_glue() -> cfg.ExperimentConfig:
config = cfg.ExperimentConfig(
task=sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
type="any", any=LongformerEncoderConfig())),
train_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(),
validation_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(
is_training=False, drop_remainder=False)),
trainer=cfg.TrainerConfig(
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate':
0.01,
'exclude_from_weight_decay':
['LayerNorm', 'layer_norm', 'bias'],
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 3e-5,
'end_learning_rate': 0.0,
}
},
'warmup': {
'type': 'polynomial'
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
task=sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
type="any", any=LongformerEncoderConfig())),
train_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(),
validation_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(
is_training=False, drop_remainder=False)),
trainer=cfg.TrainerConfig(
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate':
0.01,
'exclude_from_weight_decay':
['LayerNorm', 'layer_norm', 'bias'],
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 3e-5,
'end_learning_rate': 0.0,
}
},
'warmup': {
'type': 'polynomial'
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
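Both factories only register configurations; a hedged sketch of how they would typically be retrieved (exp_factory.get_exp_config is assumed from official.core, as used elsewhere in the Model Garden, and importing this module is what triggers registration):

from official.core import exp_factory
from official.projects.longformer import longformer_experiments  # pylint: disable=unused-import

config = exp_factory.get_exp_config('longformer/pretraining')
print(config.task.model.encoder.type)  # 'any', wrapping LongformerEncoderConfig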
......@@ -24,7 +24,6 @@ from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.projects.longformer import longformer_experiments
FLAGS = flags.FLAGS
......@@ -43,23 +42,24 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
performance.set_mixed_precision_policy(
params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu,
**params.runtime.model_parallelism())
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu,
**params.runtime.model_parallelism())
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
......