Unverified Commit c126a239 authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix tflongformer int dtype (#18907)

* Use int64 throughout TFLongFormer

* make style

* Do some more fixed casting in TFLongFormer

* Fix some wonky "is None" conditionals

* Cast all the dtypes, salt the earth

* Fix copies to TFLED as well and do some casting there

* dtype fix in TFLongformer test

* Make fixup

* Expand tolerances on the LED tests too (I think this is a TF32 thing)

* Expand test tolerances for LED a tiny bit (probably a Tensorfloat thing again)
parent f7ceda34
......@@ -472,7 +472,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
)
first_chunk_mask = (
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
tf.range(chunks_count + 1, dtype=tf.int64)[None, :, None, None],
(batch_size * num_heads, 1, window_overlap, window_overlap),
)
< 1
......@@ -1335,10 +1335,10 @@ class TFLEDPreTrainedModel(TFPreTrainedModel):
@property
def dummy_inputs(self):
input_ids = tf.convert_to_tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0]])
input_ids = tf.convert_to_tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0]], dtype=tf.int64)
# make sure global layers are initialized
attention_mask = tf.convert_to_tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0]])
global_attention_mask = tf.convert_to_tensor([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0]])
attention_mask = tf.convert_to_tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0]], dtype=tf.int64)
global_attention_mask = tf.convert_to_tensor([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0]], dtype=tf.int64)
dummy_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
......@@ -1350,10 +1350,10 @@ class TFLEDPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int64, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int64, name="decoder_attention_mask"),
}
]
)
......
......@@ -395,11 +395,10 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
Computes global attention mask by putting attention on all tokens before `sep_token_id` if `before_sep_token is
True` else after `sep_token_id`.
"""
assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None]
# bool attention mask with True in locations of global attention
attention_mask = tf.expand_dims(tf.range(input_ids_shape[1]), axis=0)
attention_mask = tf.expand_dims(tf.range(input_ids_shape[1], dtype=tf.int64), axis=0)
attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
if before_sep_token is True:
question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
......@@ -468,10 +467,9 @@ class TFLongformerLMHead(tf.keras.layers.Layer):
return hidden_states
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaEmbeddings with Roberta->Longformer
class TFLongformerEmbeddings(tf.keras.layers.Layer):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing and some extra casting.
"""
def __init__(self, config, **kwargs):
......@@ -547,7 +545,7 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None:
token_type_ids = tf.fill(dims=input_shape, value=0)
token_type_ids = tf.cast(tf.fill(dims=input_shape, value=0), tf.int64)
if position_ids is None:
if input_ids is not None:
......@@ -557,7 +555,8 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
)
else:
position_ids = tf.expand_dims(
tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1, dtype=tf.int64),
axis=0,
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
......@@ -998,7 +997,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
first_chunk_mask = (
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
tf.range(chunks_count + 1, dtype=tf.int64)[None, :, None, None],
(batch_size * num_heads, 1, window_overlap, window_overlap),
)
< 1
......@@ -1701,6 +1700,21 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
training=False,
):
if input_ids is not None and not isinstance(input_ids, tf.Tensor):
input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64)
elif input_ids is not None:
input_ids = tf.cast(input_ids, tf.int64)
if attention_mask is not None and not isinstance(attention_mask, tf.Tensor):
attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64)
elif attention_mask is not None:
attention_mask = tf.cast(attention_mask, tf.int64)
if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor):
global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64)
elif global_attention_mask is not None:
global_attention_mask = tf.cast(global_attention_mask, tf.int64)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
......@@ -1711,10 +1725,10 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
attention_mask = tf.cast(tf.fill(input_shape, 1), tf.int64)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
token_type_ids = tf.cast(tf.fill(input_shape, 0), tf.int64)
# merge `global_attention_mask` and `attention_mask`
if global_attention_mask is not None:
......@@ -1831,7 +1845,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
if inputs_embeds is not None:
def pad_embeddings():
input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
input_ids_padding = tf.cast(tf.fill((batch_size, padding_len), self.pad_token_id), tf.int64)
inputs_embeds_padding = self.embeddings(input_ids_padding)
return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
......@@ -1875,10 +1889,15 @@ class TFLongformerPreTrainedModel(TFPreTrainedModel):
@property
def dummy_inputs(self):
input_ids = tf.convert_to_tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
input_ids = tf.convert_to_tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]], dtype=tf.int64)
# make sure global layers are initialized
attention_mask = tf.convert_to_tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
global_attention_mask = tf.convert_to_tensor([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]])
attention_mask = tf.convert_to_tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int64)
global_attention_mask = tf.convert_to_tensor(
[[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]], dtype=tf.int64
)
global_attention_mask = tf.convert_to_tensor(
[[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]], dtype=tf.int64
)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
......@@ -1888,8 +1907,8 @@ class TFLongformerPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)
......@@ -2235,6 +2254,21 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
are not taken into account for computing the loss.
"""
if input_ids is not None and not isinstance(input_ids, tf.Tensor):
input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64)
elif input_ids is not None:
input_ids = tf.cast(input_ids, tf.int64)
if attention_mask is not None and not isinstance(attention_mask, tf.Tensor):
attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64)
elif attention_mask is not None:
attention_mask = tf.cast(attention_mask, tf.int64)
if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor):
global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64)
elif global_attention_mask is not None:
global_attention_mask = tf.cast(global_attention_mask, tf.int64)
# set global attention on question tokens
if global_attention_mask is None and input_ids is not None:
if shape_list(tf.where(input_ids == self.config.sep_token_id))[0] != 3 * shape_list(input_ids)[0]:
......@@ -2244,12 +2278,12 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
" forward function to avoid this. This is most likely an error. The global attention is disabled"
" for this forward pass."
)
global_attention_mask = tf.fill(shape_list(input_ids), value=0)
global_attention_mask = tf.cast(tf.fill(shape_list(input_ids), value=0), tf.int64)
else:
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
sep_token_indices = tf.where(input_ids == self.config.sep_token_id)
sep_token_indices = tf.cast(sep_token_indices, dtype=input_ids.dtype)
sep_token_indices = tf.cast(sep_token_indices, dtype=tf.int64)
global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices)
outputs = self.longformer(
......@@ -2375,13 +2409,28 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
training: Optional[bool] = False,
) -> Union[TFLongformerSequenceClassifierOutput, Tuple[tf.Tensor]]:
if input_ids is not None and not isinstance(input_ids, tf.Tensor):
input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64)
elif input_ids is not None:
input_ids = tf.cast(input_ids, tf.int64)
if attention_mask is not None and not isinstance(attention_mask, tf.Tensor):
attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64)
elif attention_mask is not None:
attention_mask = tf.cast(attention_mask, tf.int64)
if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor):
global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64)
elif global_attention_mask is not None:
global_attention_mask = tf.cast(global_attention_mask, tf.int64)
if global_attention_mask is None and input_ids is not None:
logger.info("Initializing global attention on CLS token...")
# global attention on cls token
global_attention_mask = tf.zeros_like(input_ids)
updates = tf.ones(shape_list(input_ids)[0], dtype=tf.int32)
updates = tf.ones(shape_list(input_ids)[0], dtype=tf.int64)
indices = tf.pad(
tensor=tf.expand_dims(tf.range(shape_list(input_ids)[0]), axis=1),
tensor=tf.expand_dims(tf.range(shape_list(input_ids)[0], dtype=tf.int64), axis=1),
paddings=[[0, 0], [0, 1]],
constant_values=0,
)
......@@ -2453,9 +2502,9 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
@property
def dummy_inputs(self):
input_ids = tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS)
input_ids = tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int64)
# make sure global layers are initialized
global_attention_mask = tf.convert_to_tensor([[[0, 0, 0, 1], [0, 0, 0, 1]]] * 2)
global_attention_mask = tf.convert_to_tensor([[[0, 0, 0, 1], [0, 0, 0, 1]]] * 2, dtype=tf.int64)
return {"input_ids": input_ids, "global_attention_mask": global_attention_mask}
@unpack_inputs
......@@ -2547,8 +2596,8 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"),
}
]
)
......
......@@ -412,7 +412,7 @@ class TFLEDModelIntegrationTest(unittest.TestCase):
expected_slice = tf.convert_to_tensor(
[[2.3050, 2.8279, 0.6531], [-1.8457, -0.1455, -3.5661], [-1.0186, 0.4586, -2.2043]],
)
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-3)
def test_inference_with_head(self):
model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
......@@ -428,4 +428,4 @@ class TFLEDModelIntegrationTest(unittest.TestCase):
expected_slice = tf.convert_to_tensor(
[[33.6507, 6.4572, 16.8089], [5.8739, -2.4238, 11.2902], [-3.2139, -4.3149, 4.2783]],
)
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-3, rtol=1e-3)
......@@ -115,7 +115,7 @@ class TFLongformerModelTester:
):
model = TFLongformerModel(config=config)
attention_mask = tf.ones(input_ids.shape, dtype=tf.dtypes.int32)
attention_mask = tf.ones(input_ids.shape, dtype=tf.int64)
output_with_mask = model(input_ids, attention_mask=attention_mask)[0]
output_without_mask = model(input_ids)[0]
tf.debugging.assert_near(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], rtol=1e-4)
......@@ -403,26 +403,24 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
tf.debugging.assert_near(padded_hidden_states[0, 0, 0, :4], chunked_hidden_states[0, 0, 0], rtol=1e-3)
tf.debugging.assert_near(padded_hidden_states[0, 0, 0, 4:], tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3)
tf.debugging.assert_near(padded_hidden_states[0, 0, 0, 4:], tf.zeros((3,), dtype=tf.float32), rtol=1e-3)
# last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
tf.debugging.assert_near(padded_hidden_states[0, 0, -1, 3:], chunked_hidden_states[0, 0, -1], rtol=1e-3)
tf.debugging.assert_near(
padded_hidden_states[0, 0, -1, :3], tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3
)
tf.debugging.assert_near(padded_hidden_states[0, 0, -1, :3], tf.zeros((3,), dtype=tf.float32), rtol=1e-3)
def test_pad_and_transpose_last_two_dims(self):
hidden_states = self._get_hidden_states()
self.assertEqual(shape_list(hidden_states), [1, 4, 8])
# pad along seq length dim
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.int64)
hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5])
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
expected_added_dim = tf.zeros((5,), dtype=tf.float32)
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
tf.debugging.assert_near(
hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
......@@ -441,10 +439,10 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
hid_states_3 = TFLongformerSelfAttention._mask_invalid_locations(hidden_states[:, :, :, :3], 2)
hid_states_4 = TFLongformerSelfAttention._mask_invalid_locations(hidden_states[:, :, 2:, :], 2)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)) == 8)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)) == 24)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)) == 24)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)) == 12)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_1), tf.int64)) == 8)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_2), tf.int64)) == 24)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_3), tf.int64)) == 24)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_4), tf.int64)) == 12)
def test_chunk(self):
hidden_states = self._get_hidden_states()
......@@ -456,12 +454,14 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
chunked_hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
# expected slices across chunk and seq length dim
expected_slice_along_seq_length = tf.convert_to_tensor([0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
expected_slice_along_chunk = tf.convert_to_tensor([0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)
expected_slice_along_seq_length = tf.convert_to_tensor([0.4983, -0.7584, -1.6944], dtype=tf.float32)
expected_slice_along_chunk = tf.convert_to_tensor([0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.float32)
self.assertTrue(shape_list(chunked_hidden_states) == [1, 3, 4, 4])
tf.debugging.assert_near(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, rtol=1e-3)
tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, rtol=1e-3)
tf.debugging.assert_near(
chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, rtol=1e-3, atol=1e-4
)
tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, rtol=1e-3, atol=1e-4)
def test_layer_local_attn(self):
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
......@@ -469,7 +469,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
hidden_states = self._get_hidden_states()
batch_size, seq_length, hidden_size = hidden_states.shape
attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.float32)
is_index_global_attn = tf.math.greater(attention_mask, 1)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
......@@ -483,11 +483,11 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
)[0]
expected_slice = tf.convert_to_tensor(
[0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
[0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.float32
)
self.assertEqual(output_hidden_states.shape, (1, 4, 8))
tf.debugging.assert_near(output_hidden_states[0, 1], expected_slice, rtol=1e-3)
tf.debugging.assert_near(output_hidden_states[0, 1], expected_slice, rtol=1e-3, atol=1e-4)
def test_layer_global_attn(self):
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
......@@ -498,8 +498,8 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
batch_size, seq_length, hidden_size = hidden_states.shape
# create attn mask
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.float32)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
......@@ -525,15 +525,15 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
self.assertEqual(output_hidden_states.shape, (2, 4, 8))
expected_slice_0 = tf.convert_to_tensor(
[-0.06508, -0.039306, 0.030934, -0.03417, -0.00656, -0.01553, -0.02088, -0.04938], dtype=tf.dtypes.float32
[-0.06508, -0.039306, 0.030934, -0.03417, -0.00656, -0.01553, -0.02088, -0.04938], dtype=tf.float32
)
expected_slice_1 = tf.convert_to_tensor(
[-0.04055, -0.038399, 0.0396, -0.03735, -0.03415, 0.01357, 0.00145, -0.05709], dtype=tf.dtypes.float32
[-0.04055, -0.038399, 0.0396, -0.03735, -0.03415, 0.01357, 0.00145, -0.05709], dtype=tf.float32
)
tf.debugging.assert_near(output_hidden_states[0, 2], expected_slice_0, rtol=1e-3)
tf.debugging.assert_near(output_hidden_states[1, -2], expected_slice_1, rtol=1e-3)
tf.debugging.assert_near(output_hidden_states[0, 2], expected_slice_0, rtol=1e-3, atol=1e-4)
tf.debugging.assert_near(output_hidden_states[1, -2], expected_slice_1, rtol=1e-3, atol=1e-4)
def test_layer_attn_probs(self):
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
......@@ -542,8 +542,8 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
batch_size, seq_length, hidden_size = hidden_states.shape
# create attn mask
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.float32)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
......@@ -584,18 +584,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
tf.debugging.assert_near(
local_attentions[0, 0, 0, :],
tf.convert_to_tensor(
[0.3328, 0.0000, 0.0000, 0.0000, 0.0000, 0.3355, 0.3318, 0.0000], dtype=tf.dtypes.float32
),
tf.convert_to_tensor([0.3328, 0.0000, 0.0000, 0.0000, 0.0000, 0.3355, 0.3318, 0.0000], dtype=tf.float32),
rtol=1e-3,
atol=1e-4,
)
tf.debugging.assert_near(
local_attentions[1, 0, 0, :],
tf.convert_to_tensor(
[0.2492, 0.2502, 0.2502, 0.0000, 0.0000, 0.2505, 0.0000, 0.0000], dtype=tf.dtypes.float32
),
tf.convert_to_tensor([0.2492, 0.2502, 0.2502, 0.0000, 0.0000, 0.2505, 0.0000, 0.0000], dtype=tf.float32),
rtol=1e-3,
atol=1e-4,
)
# All the global attention weights must sum to 1.
......@@ -603,13 +601,15 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
tf.debugging.assert_near(
global_attentions[0, 0, 1, :],
tf.convert_to_tensor([0.2500, 0.2500, 0.2500, 0.2500], dtype=tf.dtypes.float32),
tf.convert_to_tensor([0.2500, 0.2500, 0.2500, 0.2500], dtype=tf.float32),
rtol=1e-3,
atol=1e-4,
)
tf.debugging.assert_near(
global_attentions[1, 0, 0, :],
tf.convert_to_tensor([0.2497, 0.2500, 0.2499, 0.2504], dtype=tf.dtypes.float32),
tf.convert_to_tensor([0.2497, 0.2500, 0.2499, 0.2504], dtype=tf.float32),
rtol=1e-3,
atol=1e-4,
)
@slow
......@@ -617,31 +617,31 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")
# 'Hello world!'
input_ids = tf.convert_to_tensor([[0, 20920, 232, 328, 1437, 2]], dtype=tf.dtypes.int32)
attention_mask = tf.ones(shape_list(input_ids), dtype=tf.dtypes.int32)
input_ids = tf.convert_to_tensor([[0, 20920, 232, 328, 1437, 2]], dtype=tf.int64)
attention_mask = tf.ones(shape_list(input_ids), dtype=tf.int64)
output = model(input_ids, attention_mask=attention_mask)[0]
output_without_mask = model(input_ids)[0]
expected_output_slice = tf.convert_to_tensor(
[0.0549, 0.1087, -0.1119, -0.0368, 0.0250], dtype=tf.dtypes.float32
)
expected_output_slice = tf.convert_to_tensor([0.0549, 0.1087, -0.1119, -0.0368, 0.0250], dtype=tf.float32)
tf.debugging.assert_near(output[0, 0, -5:], expected_output_slice, rtol=1e-3)
tf.debugging.assert_near(output_without_mask[0, 0, -5:], expected_output_slice, rtol=1e-3)
tf.debugging.assert_near(output[0, 0, -5:], expected_output_slice, rtol=1e-3, atol=1e-4)
tf.debugging.assert_near(output_without_mask[0, 0, -5:], expected_output_slice, rtol=1e-3, atol=1e-4)
@slow
def test_inference_no_head_long(self):
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")
# 'Hello world! ' repeated 1000 times
input_ids = tf.convert_to_tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=tf.dtypes.int32)
input_ids = tf.convert_to_tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=tf.int64)
attention_mask = tf.ones(shape_list(input_ids), dtype=tf.dtypes.int32)
global_attention_mask = tf.zeros(shape_list(input_ids), dtype=tf.dtypes.int32)
attention_mask = tf.ones(shape_list(input_ids), dtype=tf.int64)
global_attention_mask = tf.zeros(shape_list(input_ids), dtype=tf.int64)
# Set global attention on a few random positions
global_attention_mask = tf.tensor_scatter_nd_update(
global_attention_mask, tf.constant([[0, 1], [0, 4], [0, 21]]), tf.constant([1, 1, 1])
global_attention_mask,
tf.constant([[0, 1], [0, 4], [0, 21]], dtype=tf.int64),
tf.constant([1, 1, 1], dtype=tf.int64),
)
output = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)[0]
......@@ -650,15 +650,15 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
expected_output_mean = tf.constant(0.024267)
# assert close
tf.debugging.assert_near(tf.reduce_sum(output), expected_output_sum, rtol=1e-4)
tf.debugging.assert_near(tf.reduce_mean(output), expected_output_mean, rtol=1e-4)
tf.debugging.assert_near(tf.reduce_sum(output), expected_output_sum, rtol=1e-4, atol=1e-4)
tf.debugging.assert_near(tf.reduce_mean(output), expected_output_mean, rtol=1e-4, atol=1e-4)
@slow
def test_inference_masked_lm_long(self):
model = TFLongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
# 'Hello world! ' repeated 1000 times
input_ids = tf.convert_to_tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=tf.dtypes.int32)
input_ids = tf.convert_to_tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=tf.int64)
output = model(input_ids, labels=input_ids)
loss = output.loss
......@@ -669,9 +669,13 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
expected_prediction_scores_mean = tf.constant(-3.03477)
# assert close
tf.debugging.assert_near(tf.reduce_mean(loss), expected_loss, rtol=1e-4)
tf.debugging.assert_near(tf.reduce_sum(prediction_scores), expected_prediction_scores_sum, rtol=1e-4)
tf.debugging.assert_near(tf.reduce_mean(prediction_scores), expected_prediction_scores_mean, rtol=1e-4)
tf.debugging.assert_near(tf.reduce_mean(loss), expected_loss, rtol=1e-4, atol=1e-4)
tf.debugging.assert_near(
tf.reduce_sum(prediction_scores), expected_prediction_scores_sum, rtol=1e-4, atol=1e-4
)
tf.debugging.assert_near(
tf.reduce_mean(prediction_scores), expected_prediction_scores_mean, rtol=1e-4, atol=1e-4
)
@slow
def test_inference_masked_lm(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment