Commit dc588495 authored by Zihan Wang

use tf-utils.get_shape_list

parent 8c430b98
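Note: `official.modeling.tf_utils.get_shape_list` is used as a drop-in replacement for the local `shape_list` helper removed in this commit; both return the static dimension where it is known and a dynamic `tf.shape` value where it is not. Below is a minimal, illustrative sketch of that behavior (not part of the diff; the `fn` wrapper is a hypothetical example):

import tensorflow as tf
from official.modeling.tf_utils import get_shape_list

# For a concrete eager tensor every dimension is static.
print(get_shape_list(tf.ones([2, 16, 64])))  # -> [2, 16, 64]

# With unknown batch/sequence dims, the unknown entries come back as scalar
# tf.Tensor values from tf.shape, so they can still be used in reshapes.
@tf.function(input_signature=[tf.TensorSpec([None, None, 64], tf.float32)])
def fn(t):
  batch_size, seq_len, hidden = get_shape_list(t)  # batch_size/seq_len are tensors, hidden == 64
  return tf.reshape(t, (batch_size * seq_len, hidden))

print(fn(tf.ones([2, 16, 64])).shape)  # (32, 64)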
@@ -34,28 +34,9 @@ from keras.layers import einsum_dense
 from keras.utils import tf_utils
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import keras_export
+from official.modeling.tf_utils import get_shape_list
 from typing import Dict, List, Optional, Union
-def shape_list(tensor: tf.Tensor) -> List[int]:
-    """
-    Deal with dynamic shape in tensorflow cleanly.
-    Args:
-        tensor (:obj:`tf.Tensor`): The tensor we want the shape of.
-    Returns:
-        :obj:`List[int]`: The shape of the tensor as a list.
-    """
-    dynamic = tf.shape(tensor)
-    if tensor.shape == tf.TensorShape(None):
-        return dynamic
-    static = tensor.shape.as_list()
-    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
 _CHR_IDX = string.ascii_lowercase
 def _build_attention_equation(rank, attn_axes):
@@ -292,7 +273,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 # XLA performance, but may introduce slight numeric differences in
 # the Transformer attention head.
 query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))  # (B, T, N, key_dim)
-batch_size, seq_len, num_heads, head_dim = shape_list(query)
+batch_size, seq_len, num_heads, head_dim = get_shape_list(query)
 # attn_probs = (batch_size, seq_len, num_heads, window*2+1)
 attn_scores = self._sliding_chunks_query_key_matmul(
@@ -301,7 +282,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 # diagonal mask with zeros everywhere and -inf inplace of padding
 diagonal_mask = self._sliding_chunks_query_key_matmul(
-    tf.ones(shape_list(attention_mask)),
+    tf.ones(get_shape_list(attention_mask)),
     attention_mask,
     self._one_sided_attn_window_size,
 )
@@ -311,9 +292,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 if tf.executing_eagerly():
     tf.debugging.assert_equal(
-        shape_list(attn_scores),
+        get_shape_list(attn_scores),
         [batch_size, seq_len, self._num_heads, self._one_sided_attn_window_size * 2 + 1],
-        message=f"attn_probs should be of size ({batch_size}, {seq_len}, {num_heads}, {self._one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}",
+        message=f"attn_probs should be of size ({batch_size}, {seq_len}, {num_heads}, {self._one_sided_attn_window_size * 2 + 1}), but is of size {get_shape_list(attn_scores)}",
     )
 # compute global attn indices required through out forward fn
@@ -356,7 +337,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 attn_probs = tf.where(
     masked_index,
-    tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
+    tf.zeros(get_shape_list(masked_index), dtype=attn_probs.dtype),
     attn_probs,
 )
@@ -364,9 +345,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 if layer_head_mask is not None:
     if tf.executing_eagerly():
         tf.debugging.assert_equal(
-            shape_list(layer_head_mask),
+            get_shape_list(layer_head_mask),
             [self._num_heads],
-            message=f"Head mask for a single layer should be of size {(self._num_heads)}, but is {shape_list(layer_head_mask)}",
+            message=f"Head mask for a single layer should be of size {(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
         )
     attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
@@ -391,7 +372,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 if tf.executing_eagerly():
     tf.debugging.assert_equal(
-        shape_list(attn_output),
+        get_shape_list(attn_output),
         [batch_size, seq_len, self._num_heads, head_dim],
         message="Unexpected size",
     )
@@ -432,7 +413,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 attn_probs = tf.where(
     masked_global_attn_index,
-    tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
+    tf.zeros(get_shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
     attn_probs,
 )
@@ -455,7 +436,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an
 overlap of size window_overlap
 """
-batch_size, seq_len, num_heads, head_dim = shape_list(query)
+batch_size, seq_len, num_heads, head_dim = get_shape_list(query)
 if tf.executing_eagerly():
     tf.debugging.assert_equal(
@@ -464,9 +445,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
         message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
     )
     tf.debugging.assert_equal(
-        shape_list(query),
-        shape_list(key),
+        get_shape_list(query),
+        get_shape_list(key),
         message=f"Shape of query and key should be equal, but got query: {get_shape_list(query)} and key: {get_shape_list(key)}",
     )
 chunks_count = seq_len // window_overlap - 1
@@ -574,7 +555,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 # pad to full matrix
 padding = tf.convert_to_tensor(
-    [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
+    [[0, get_shape_list(input_tensor)[1] - window_overlap], [0, get_shape_list(input_tensor)[3] - window_overlap - 1]]
 )
 # create lower mask
@@ -584,7 +565,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
 # broadcast to full matrix
-mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
+mask_4d = tf.tile(mask_2d[None, :, None, :], (get_shape_list(input_tensor)[0], 1, 1, 1))
 # inf tensor used for masking
 inf_tensor = -float("inf") * tf.ones_like(input_tensor)
@@ -600,7 +581,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 same shape as `attn_probs`
 """
-batch_size, seq_len, num_heads, head_dim = shape_list(value)
+batch_size, seq_len, num_heads, head_dim = get_shape_list(value)
 if tf.executing_eagerly():
     tf.debugging.assert_equal(
@@ -609,12 +590,12 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
         message="Seq_len has to be multiple of 2 * window_overlap",
     )
     tf.debugging.assert_equal(
-        shape_list(attn_probs)[:3],
-        shape_list(value)[:3],
+        get_shape_list(attn_probs)[:3],
+        get_shape_list(value)[:3],
         message="value and attn_probs must have same dims (except head_dim)",
     )
     tf.debugging.assert_equal(
-        shape_list(attn_probs)[3],
+        get_shape_list(attn_probs)[3],
         2 * window_overlap + 1,
         message="attn_probs last dim has to be 2 * window_overlap + 1",
     )
@@ -644,7 +625,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
 frame_size = 3 * window_overlap * head_dim
-frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
+frame_hop_size = (get_shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
 chunked_value = tf.signal.frame(
     tf.reshape(padded_value, (batch_size * num_heads, -1)),
     frame_size,
@@ -657,7 +638,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 if tf.executing_eagerly():
     tf.debugging.assert_equal(
-        shape_list(chunked_value),
+        get_shape_list(chunked_value),
         [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
         message="Chunked value has the wrong shape",
     )
@@ -677,7 +658,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 hidden_states_padded = tf.pad(
     hidden_states_padded, paddings
 )  # padding value is not important because it will be overwritten
-batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
+batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(hidden_states_padded)
 hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
 return hidden_states_padded
@@ -700,7 +681,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
 """
-total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
+total_num_heads, num_chunks, window_overlap, hidden_dim = get_shape_list(chunked_hidden_states)
 paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
 chunked_hidden_states = tf.pad(
     chunked_hidden_states, paddings
@@ -722,7 +703,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 @staticmethod
 def _chunk(hidden_states, window_overlap):
     """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
-    batch_size, seq_length, hidden_dim = shape_list(hidden_states)
+    batch_size, seq_length, hidden_dim = get_shape_list(hidden_states)
     num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1
     # define frame size and frame stride (similar to convolution)
@@ -735,9 +716,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 if tf.executing_eagerly():
     tf.debugging.assert_equal(
-        shape_list(chunked_hidden_states),
+        get_shape_list(chunked_hidden_states),
         [batch_size, num_output_chunks, frame_size],
-        message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
+        message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {get_shape_list(chunked_hidden_states)}.",
     )
 chunked_hidden_states = tf.reshape(
@@ -752,7 +733,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 """compute global attn indices required throughout forward pass"""
 # All global attention size are fixed through global_attention_size
-batch_size, seq_len = shape_list(is_index_global_attn)
+batch_size, seq_len = get_shape_list(is_index_global_attn)
 max_num_global_attn_indices = global_attention_size
@@ -787,7 +768,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
     is_local_index_global_attn_nonzero,
     is_local_index_no_global_attn_nonzero,
 ):
-    batch_size = shape_list(key_vectors)[0]
+    batch_size = get_shape_list(key_vectors)[0]
     # select global key vectors
     global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)
@@ -809,8 +790,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 # (batch_size, max_num_global_attn_indices, seq_len, num_heads)
 attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2))
-mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
-    shape_list(attn_probs_from_global_key_trans)[-2:]
+mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
+    get_shape_list(attn_probs_from_global_key_trans)[-2:]
 )
 mask = tf.ones(mask_shape) * -10000.0
 mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)
@@ -838,11 +819,11 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
     is_index_global_attn_nonzero,
     is_local_index_global_attn_nonzero,
 ):
-    batch_size = shape_list(attn_probs)[0]
+    batch_size = get_shape_list(attn_probs)[0]
     # cut local attn probs to global only
     attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]
-    # attn_probs_only_global = tf.slice(attn_probs, [0, 0, 0, 0], shape_list(attn_probs)[: -1] + [max_num_global_attn_indices])
+    # attn_probs_only_global = tf.slice(attn_probs, [0, 0, 0, 0], get_shape_list(attn_probs)[: -1] + [max_num_global_attn_indices])
     # select global value vectors
     global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero)
@@ -863,7 +844,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global)
 # reshape attn probs
 attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:]
-# attn_probs_without_global = tf.slice(attn_probs, [0, 0, 0, max_num_global_attn_indices], shape_list(attn_probs)[: -1] + [shape_list(attn_probs)[-1] - max_num_global_attn_indices])
+# attn_probs_without_global = tf.slice(attn_probs, [0, 0, 0, max_num_global_attn_indices], get_shape_list(attn_probs)[: -1] + [get_shape_list(attn_probs)[-1] - max_num_global_attn_indices])
 # compute attn output with global
 attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
@@ -884,7 +865,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
     is_index_masked,
     training,
 ):
-    batch_size, seq_len = shape_list(hidden_states)[:2]
+    batch_size, seq_len = get_shape_list(hidden_states)[:2]
     # prepare global hidden states
     global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero)
@@ -912,9 +893,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 if tf.executing_eagerly():
     tf.debugging.assert_equal(
-        shape_list(global_attn_scores),
+        get_shape_list(global_attn_scores),
         [batch_size * self._num_heads, max_num_global_attn_indices, seq_len],
-        message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.",
+        message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, but is {get_shape_list(global_attn_scores)}.",
     )
 global_attn_scores = tf.reshape(
@@ -922,8 +903,8 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
     (batch_size, self._num_heads, max_num_global_attn_indices, seq_len),
 )
 global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
-mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
-    shape_list(global_attn_scores_trans)[-2:]
+mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
+    get_shape_list(global_attn_scores_trans)[-2:]
 )
 global_attn_mask = tf.ones(mask_shape) * -10000.0
 global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)
@@ -937,7 +918,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
 # mask global attn scores
-attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
+attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, get_shape_list(global_attn_scores)[1], 1, 1))
 global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
 global_attn_scores = tf.reshape(
     global_attn_scores,
@@ -951,9 +932,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 if layer_head_mask is not None:
     if tf.executing_eagerly():
         tf.debugging.assert_equal(
-            shape_list(layer_head_mask),
+            get_shape_list(layer_head_mask),
             [self._num_heads],
-            message=f"Head mask for a single layer should be of size {(self._num_heads)}, but is {shape_list(layer_head_mask)}",
+            message=f"Head mask for a single layer should be of size {(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
         )
     global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
         global_attn_probs_float, (batch_size, self._num_heads, max_num_global_attn_indices, seq_len)
@@ -970,9 +951,9 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 if tf.executing_eagerly():
     tf.debugging.assert_equal(
-        shape_list(global_attn_output),
+        get_shape_list(global_attn_output),
         [batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim],
-        message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, but is {shape_list(global_attn_output)}.",
+        message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, but is {get_shape_list(global_attn_output)}.",
     )
 global_attn_output = tf.reshape(
@@ -987,7 +968,7 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
 )
 nonzero_global_attn_output = tf.reshape(
     nonzero_global_attn_output,
-    (shape_list(is_local_index_global_attn_nonzero)[0], -1),
+    (get_shape_list(is_local_index_global_attn_nonzero)[0], -1),
 )
 # overwrite values with global attention
...
@@ -20,6 +20,7 @@ import tensorflow as tf
 from tensorflow.python.distribute import combinations
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
 from official.projects.longformer import longformer_attention
+from official.modeling.tf_utils import get_shape_list
 def _create_mock_attention_data(
@@ -117,13 +118,13 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
 hidden_states = self._get_hidden_states()
 hidden_states = tf.reshape(hidden_states, (1, 8, 4))  # set seq length = 8, hidden dim = 4
 chunked_hidden_states = longformer_attention.LongformerAttention._chunk(hidden_states, window_overlap=2)
-window_overlap_size = longformer_attention.shape_list(chunked_hidden_states)[2]
+window_overlap_size = get_shape_list(chunked_hidden_states)[2]
 self.assertTrue(window_overlap_size == 4)
 padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(chunked_hidden_states)
 self.assertTrue(
-    longformer_attention.shape_list(padded_hidden_states)[-1] == longformer_attention.shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
+    get_shape_list(padded_hidden_states)[-1] == get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
 )
 # first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
@@ -138,14 +139,14 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
 def test_pad_and_transpose_last_two_dims(self):
     hidden_states = self._get_hidden_states()
-    self.assertTrue(longformer_attention.shape_list(hidden_states), [1, 8, 4])
+    self.assertTrue(get_shape_list(hidden_states), [1, 8, 4])
     # pad along seq length dim
     paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
     hidden_states = longformer_attention.LongformerAttention._chunk(hidden_states, window_overlap=2)
     padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
-    self.assertTrue(longformer_attention.shape_list(padded_hidden_states) == [1, 1, 8, 5])
+    self.assertTrue(get_shape_list(padded_hidden_states) == [1, 1, 8, 5])
     expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
     tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
@@ -184,7 +185,7 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
 expected_slice_along_seq_length = tf.convert_to_tensor([0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
 expected_slice_along_chunk = tf.convert_to_tensor([0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)
-self.assertTrue(longformer_attention.shape_list(chunked_hidden_states) == [1, 3, 4, 4])
+self.assertTrue(get_shape_list(chunked_hidden_states) == [1, 3, 4, 4])
 tf.debugging.assert_near(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, rtol=1e-3)
 tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, rtol=1e-3)
...
@@ -24,25 +24,8 @@ import tensorflow as tf
 from official.nlp.modeling import layers
 from official.projects.longformer.longformer_encoder_block import LongformerEncoderBlock
+from official.modeling.tf_utils import get_shape_list
-def shape_list(tensor: tf.Tensor) -> List[int]:
-    """
-    Deal with dynamic shape in tensorflow cleanly.
-    Args:
-        tensor (:obj:`tf.Tensor`): The tensor we want the shape of.
-    Returns:
-        :obj:`List[int]`: The shape of the tensor as a list.
-    """
-    dynamic = tf.shape(tensor)
-    if tensor.shape == tf.TensorShape(None):
-        return dynamic
-    static = tensor.shape.as_list()
-    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
 _Initializer = Union[str, tf.keras.initializers.Initializer]
@@ -262,7 +245,7 @@ class LongformerEncoder(tf.keras.layers.Layer):
 if self._embedding_projection is not None:
     embeddings = self._embedding_projection(embeddings)
-batch_size, seq_len = shape_list(mask)
+batch_size, seq_len = get_shape_list(mask)
 # create masks with fixed len global_attention_size
 mask = tf.transpose(tf.concat(values=[tf.ones((self._global_attention_size, batch_size), tf.int32) * 2,
                                       tf.transpose(mask)[self._global_attention_size:]], axis=0))
@@ -353,7 +336,7 @@ class LongformerEncoder(tf.keras.layers.Layer):
 assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
-input_shape = shape_list(word_ids) if word_ids is not None else shape_list(word_embeddings)
+input_shape = get_shape_list(word_ids) if word_ids is not None else get_shape_list(word_embeddings)
 batch_size, seq_len = input_shape[:2]
 if seq_len is not None:
...