Commit dc588495 authored by Zihan Wang

use tf_utils.get_shape_list

parent 8c430b98
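
This commit replaces the file-local shape_list helper (deleted below) with get_shape_list from official.modeling.tf_utils. A minimal usage sketch, assuming get_shape_list keeps the deleted helper's contract, which its drop-in use throughout the diff implies: statically known dimensions come back as Python ints, unknown ones as scalar tensors from tf.shape.

import tensorflow as tf
from official.modeling.tf_utils import get_shape_list

# Fully defined shapes come back as plain ints.
x = tf.ones((2, 128, 64))
assert get_shape_list(x) == [2, 128, 64]

@tf.function(input_signature=[tf.TensorSpec([None, None, 64])])
def unpack_dims(t):
  # Unknown batch/sequence dims are returned as scalar tensors, so the
  # `batch_size, seq_len, ... = get_shape_list(...)` unpacking used
  # throughout the diff keeps working inside tf.function.
  batch_size, seq_len, hidden_dim = get_shape_list(t)
  return tf.stack([batch_size, seq_len, hidden_dim])

print(unpack_dims(tf.ones((2, 7, 64))).numpy())  # [2 7 64]
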
@@ -34,28 +34,9 @@ from keras.layers import einsum_dense
from keras.utils import tf_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
from official.modeling.tf_utils import get_shape_list
from typing import Dict, List, Optional, Union
def shape_list(tensor: tf.Tensor) -> List[int]:
"""
Deal with dynamic shape in tensorflow cleanly.
Args:
tensor (:obj:`tf.Tensor`): The tensor we want the shape of.
Returns:
:obj:`List[int]`: The shape of the tensor as a list.
"""
dynamic = tf.shape(tensor)
if tensor.shape == tf.TensorShape(None):
return dynamic
static = tensor.shape.as_list()
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(rank, attn_axes):
@@ -292,7 +273,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim))) # (B, T, N, key_dim)
batch_size, seq_len, num_heads, head_dim = shape_list(query)
batch_size, seq_len, num_heads, head_dim = get_shape_list(query)
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
attn_scores = self._sliding_chunks_query_key_matmul(
@@ -301,7 +282,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
# diagonal mask with zeros everywhere and -inf in place of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
tf.ones(shape_list(attention_mask)),
tf.ones(get_shape_list(attention_mask)),
attention_mask,
self._one_sided_attn_window_size,
)
@@ -311,9 +292,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_scores),
get_shape_list(attn_scores),
[batch_size, seq_len, self._num_heads, self._one_sided_attn_window_size * 2 + 1],
message=f"attn_probs should be of size ({batch_size}, {seq_len}, {num_heads}, {self._one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}",
message=f"attn_probs should be of size ({batch_size}, {seq_len}, {num_heads}, {self._one_sided_attn_window_size * 2 + 1}), but is of size {get_shape_list(attn_scores)}",
)
# compute global attn indices required throughout forward fn
@@ -356,7 +337,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
attn_probs = tf.where(
masked_index,
tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
tf.zeros(get_shape_list(masked_index), dtype=attn_probs.dtype),
attn_probs,
)
@@ -364,9 +345,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
get_shape_list(layer_head_mask),
[self._num_heads],
message=f"Head mask for a single layer should be of size {(self._num_heads)}, but is {shape_list(layer_head_mask)}",
message=f"Head mask for a single layer should be of size {(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
@@ -391,7 +372,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
get_shape_list(attn_output),
[batch_size, seq_len, self._num_heads, head_dim],
message="Unexpected size",
)
@@ -432,7 +413,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
attn_probs = tf.where(
masked_global_attn_index,
tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
tf.zeros(get_shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
attn_probs,
)
@@ -455,7 +436,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an
overlap of size window_overlap
"""
batch_size, seq_len, num_heads, head_dim = shape_list(query)
batch_size, seq_len, num_heads, head_dim = get_shape_list(query)
if tf.executing_eagerly():
tf.debugging.assert_equal(
@@ -464,9 +445,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
)
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}",
get_shape_list(query),
get_shape_list(key),
message=f"Shape of query and key should be equal, but got query: {get_shape_list(query)} and key: {get_shape_list(key)}",
)
chunks_count = seq_len // window_overlap - 1
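
To make the chunk bookkeeping above concrete: with the seq_len = 8, window_overlap = 2 values the unit tests further down use, the overlapping-chunk counts work out as follows (plain-Python sketch, not part of the diff):

seq_len, window_overlap = 8, 2
# Chunks span 2 * window_overlap timesteps and consecutive chunks overlap by
# window_overlap, so a sequence of length seq_len yields:
chunks_count = seq_len // window_overlap - 1                    # 3
# _chunk() arrives at the same count a different way:
num_output_chunks = 2 * (seq_len // (2 * window_overlap)) - 1   # 3
assert chunks_count == num_output_chunks == 3
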
@@ -574,7 +555,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
# pad to full matrix
padding = tf.convert_to_tensor(
[[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
[[0, get_shape_list(input_tensor)[1] - window_overlap], [0, get_shape_list(input_tensor)[3] - window_overlap - 1]]
)
# create lower mask
@@ -584,7 +565,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
# broadcast to full matrix
mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
mask_4d = tf.tile(mask_2d[None, :, None, :], (get_shape_list(input_tensor)[0], 1, 1, 1))
# inf tensor used for masking
inf_tensor = -float("inf") * tf.ones_like(input_tensor)
@@ -600,7 +581,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
same shape as `attn_probs`
"""
batch_size, seq_len, num_heads, head_dim = shape_list(value)
batch_size, seq_len, num_heads, head_dim = get_shape_list(value)
if tf.executing_eagerly():
tf.debugging.assert_equal(
@@ -609,12 +590,12 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
message="Seq_len has to be multiple of 2 * window_overlap",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[:3],
shape_list(value)[:3],
get_shape_list(attn_probs)[:3],
get_shape_list(value)[:3],
message="value and attn_probs must have same dims (except head_dim)",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[3],
get_shape_list(attn_probs)[3],
2 * window_overlap + 1,
message="attn_probs last dim has to be 2 * window_overlap + 1",
)
@@ -644,7 +625,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
# chunk padded_value into chunks of size 3 * window_overlap with an overlap of size window_overlap
frame_size = 3 * window_overlap * head_dim
frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
frame_hop_size = (get_shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
chunked_value = tf.signal.frame(
tf.reshape(padded_value, (batch_size * num_heads, -1)),
frame_size,
@@ -657,7 +638,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(chunked_value),
get_shape_list(chunked_value),
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
message="Chunked value has the wrong shape",
)
@@ -677,7 +658,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
hidden_states_padded = tf.pad(
hidden_states_padded, paddings
) # padding value is not important because it will be overwritten
batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(hidden_states_padded)
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
return hidden_states_padded
@@ -700,7 +681,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
"""
total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
total_num_heads, num_chunks, window_overlap, hidden_dim = get_shape_list(chunked_hidden_states)
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
chunked_hidden_states = tf.pad(
chunked_hidden_states, paddings
@@ -722,7 +703,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
@staticmethod
def _chunk(hidden_states, window_overlap):
"""convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
batch_size, seq_length, hidden_dim = shape_list(hidden_states)
batch_size, seq_length, hidden_dim = get_shape_list(hidden_states)
num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1
# define frame size and frame stride (similar to convolution)
@@ -735,9 +716,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
get_shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {get_shape_list(chunked_hidden_states)}.",
)
chunked_hidden_states = tf.reshape(
@@ -752,7 +733,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
"""compute global attn indices required throughout forward pass"""
# All global attention sizes are fixed by global_attention_size
batch_size, seq_len = shape_list(is_index_global_attn)
batch_size, seq_len = get_shape_list(is_index_global_attn)
max_num_global_attn_indices = global_attention_size
@@ -787,7 +768,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
):
batch_size = shape_list(key_vectors)[0]
batch_size = get_shape_list(key_vectors)[0]
# select global key vectors
global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)
@@ -809,8 +790,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
# (batch_size, max_num_global_attn_indices, seq_len, num_heads)
attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2))
mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
shape_list(attn_probs_from_global_key_trans)[-2:]
mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
get_shape_list(attn_probs_from_global_key_trans)[-2:]
)
mask = tf.ones(mask_shape) * -10000.0
mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)
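
The -10000.0 additive mask built above works by pushing masked logits far below the real scores so softmax gives them numerically zero probability; a two-line illustration with made-up scores:

import tensorflow as tf

scores = tf.constant([2.0, 1.0, 0.5])
masked = scores + tf.constant([0.0, 0.0, -10000.0])  # additive mask, as above
print(tf.nn.softmax(masked).numpy())  # last entry underflows to ~0.0
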
@@ -838,11 +819,11 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
):
batch_size = shape_list(attn_probs)[0]
batch_size = get_shape_list(attn_probs)[0]
# cut local attn probs to global only
attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]
# attn_probs_only_global = tf.slice(attn_probs, [0, 0, 0, 0], shape_list(attn_probs)[: -1] + [max_num_global_attn_indices])
# attn_probs_only_global = tf.slice(attn_probs, [0, 0, 0, 0], get_shape_list(attn_probs)[: -1] + [max_num_global_attn_indices])
# select global value vectors
global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero)
@@ -863,7 +844,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global)
# reshape attn probs
attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:]
# attn_probs_without_global = tf.slice(attn_probs, [0, 0, 0, max_num_global_attn_indices], shape_list(attn_probs)[: -1] + [shape_list(attn_probs)[-1] - max_num_global_attn_indices])
# attn_probs_without_global = tf.slice(attn_probs, [0, 0, 0, max_num_global_attn_indices], get_shape_list(attn_probs)[: -1] + [get_shape_list(attn_probs)[-1] - max_num_global_attn_indices])
# compute attn output with global
attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
@@ -884,7 +865,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
is_index_masked,
training,
):
batch_size, seq_len = shape_list(hidden_states)[:2]
batch_size, seq_len = get_shape_list(hidden_states)[:2]
# prepare global hidden states
global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero)
@@ -912,9 +893,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(global_attn_scores),
get_shape_list(global_attn_scores),
[batch_size * self._num_heads, max_num_global_attn_indices, seq_len],
message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.",
message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, but is {get_shape_list(global_attn_scores)}.",
)
global_attn_scores = tf.reshape(
@@ -922,8 +903,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len),
)
global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
shape_list(global_attn_scores_trans)[-2:]
mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
get_shape_list(global_attn_scores_trans)[-2:]
)
global_attn_mask = tf.ones(mask_shape) * -10000.0
global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)
@@ -937,7 +918,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
# mask global attn scores
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, get_shape_list(global_attn_scores)[1], 1, 1))
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
global_attn_scores = tf.reshape(
global_attn_scores,
@@ -951,9 +932,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
get_shape_list(layer_head_mask),
[self._num_heads],
message=f"Head mask for a single layer should be of size {(self._num_heads)}, but is {shape_list(layer_head_mask)}",
message=f"Head mask for a single layer should be of size {(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self._num_heads, max_num_global_attn_indices, seq_len)
@@ -970,9 +951,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(global_attn_output),
get_shape_list(global_attn_output),
[batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim],
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, but is {shape_list(global_attn_output)}.",
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, but is {get_shape_list(global_attn_output)}.",
)
global_attn_output = tf.reshape(
@@ -987,7 +968,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
)
nonzero_global_attn_output = tf.reshape(
nonzero_global_attn_output,
(shape_list(is_local_index_global_attn_nonzero)[0], -1),
(get_shape_list(is_local_index_global_attn_nonzero)[0], -1),
)
# overwrite values with global attention
@@ -20,6 +20,7 @@ import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.projects.longformer import longformer_attention
from official.modeling.tf_utils import get_shape_list
def _create_mock_attention_data(
@@ -117,13 +118,13 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
hidden_states = self._get_hidden_states()
hidden_states = tf.reshape(hidden_states, (1, 8, 4)) # set seq length = 8, hidden dim = 4
chunked_hidden_states = longformer_attention.LongformerAttention._chunk(hidden_states, window_overlap=2)
window_overlap_size = longformer_attention.shape_list(chunked_hidden_states)[2]
window_overlap_size = get_shape_list(chunked_hidden_states)[2]
self.assertTrue(window_overlap_size == 4)
padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(chunked_hidden_states)
self.assertTrue(
longformer_attention.shape_list(padded_hidden_states)[-1] == longformer_attention.shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
get_shape_list(padded_hidden_states)[-1] == get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
)
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
@@ -138,14 +139,14 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
def test_pad_and_transpose_last_two_dims(self):
hidden_states = self._get_hidden_states()
self.assertTrue(longformer_attention.shape_list(hidden_states), [1, 8, 4])
self.assertTrue(get_shape_list(hidden_states), [1, 8, 4])
# pad along seq length dim
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
hidden_states = longformer_attention.LongformerAttention._chunk(hidden_states, window_overlap=2)
padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
self.assertTrue(longformer_attention.shape_list(padded_hidden_states) == [1, 1, 8, 5])
self.assertTrue(get_shape_list(padded_hidden_states) == [1, 1, 8, 5])
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
@@ -184,7 +185,7 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
expected_slice_along_seq_length = tf.convert_to_tensor([0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
expected_slice_along_chunk = tf.convert_to_tensor([0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)
self.assertTrue(longformer_attention.shape_list(chunked_hidden_states) == [1, 3, 4, 4])
self.assertTrue(get_shape_list(chunked_hidden_states) == [1, 3, 4, 4])
tf.debugging.assert_near(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, rtol=1e-3)
tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, rtol=1e-3)
@@ -24,25 +24,8 @@ import tensorflow as tf
from official.nlp.modeling import layers
from official.projects.longformer.longformer_encoder_block import LongformerEncoderBlock
from official.modeling.tf_utils import get_shape_list
def shape_list(tensor: tf.Tensor) -> List[int]:
"""
Deal with dynamic shape in tensorflow cleanly.
Args:
tensor (:obj:`tf.Tensor`): The tensor we want the shape of.
Returns:
:obj:`List[int]`: The shape of the tensor as a list.
"""
dynamic = tf.shape(tensor)
if tensor.shape == tf.TensorShape(None):
return dynamic
static = tensor.shape.as_list()
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
_Initializer = Union[str, tf.keras.initializers.Initializer]
@@ -262,7 +245,7 @@ class LongformerEncoder(tf.keras.layers.Layer):
if self._embedding_projection is not None:
embeddings = self._embedding_projection(embeddings)
batch_size, seq_len = shape_list(mask)
batch_size, seq_len = get_shape_list(mask)
# create masks with fixed len global_attention_size
mask = tf.transpose(tf.concat(values=[tf.ones((self._global_attention_size, batch_size), tf.int32) * 2,
tf.transpose(mask)[self._global_attention_size:]], axis=0))
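
A tiny illustration of the mask rewrite above, with made-up sizes (batch_size = 1, seq_len = 6, global_attention_size = 2): the first global_attention_size positions of every sequence are overwritten with 2, the value used here to mark global-attention tokens, while the remaining positions keep their original padding mask.

import tensorflow as tf

global_attention_size, batch_size = 2, 1
mask = tf.constant([[1, 1, 1, 1, 0, 0]], tf.int32)  # (batch_size, seq_len)
mask = tf.transpose(tf.concat(
    values=[tf.ones((global_attention_size, batch_size), tf.int32) * 2,
            tf.transpose(mask)[global_attention_size:]], axis=0))
print(mask.numpy())  # [[2 2 1 1 0 0]]
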
@@ -353,7 +336,7 @@ class LongformerEncoder(tf.keras.layers.Layer):
assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
input_shape = shape_list(word_ids) if word_ids is not None else shape_list(word_embeddings)
input_shape = get_shape_list(word_ids) if word_ids is not None else get_shape_list(word_embeddings)
batch_size, seq_len = input_shape[:2]
if seq_len is not None: