Unverified commit 00bb0b25, authored by Patrick von Platen, committed by GitHub

TF Longformer (#5764)



* improve names and tests longformer

* more and better tests for longformer

* add first tf test

* finalize tf basic op functions

* fix merge

* tf shape test passes

* narrow down discrepancies

* make longformer local attn tf work

* correct tf longformer

* add first global attn function

* add more global longformer func

* advance tf longformer

* finish global attn

* upload big model

* finish all tests

* correct false any statement

* fix common tests

* make all tests pass except Keras save/load

* fix some tests

* fix torch test import

* finish tests

* fix test

* fix torch tf tests

* add docs

* finish docs

* Update src/transformers/modeling_longformer.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_longformer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply Lysandre's suggestions

* revert to assert statement because function will fail otherwise

* apply Sylvain's recommendations

* Update src/transformers/modeling_longformer.py
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* Update src/transformers/modeling_tf_longformer.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 34259366
@@ -102,3 +102,25 @@ LongformerForQuestionAnswering
.. autoclass:: transformers.LongformerForQuestionAnswering
    :members:

TFLongformerModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFLongformerModel
    :members:

TFLongformerForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFLongformerForMaskedLM
    :members:

TFLongformerForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFLongformerForQuestionAnswering
    :members:
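
A minimal usage sketch (assuming, per the commit, that TF weights were uploaded for the ``allenai/longformer-base-4096`` checkpoint)::

    from transformers import LongformerTokenizer, TFLongformerModel

    tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
    model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

    inputs = tokenizer("Hello world", return_tensors="tf")
    sequence_output = model(inputs)[0]  # (batch_size, seq_len, hidden_size)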
@@ -399,6 +399,7 @@ if is_torch_available():
LongformerForMultipleChoice,
LongformerForTokenClassification,
LongformerForQuestionAnswering,
LongformerSelfAttention,
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
)
@@ -568,6 +569,14 @@ if is_tf_available():
TFGPT2PreTrainedModel,
)
from .modeling_tf_longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerModel,
TFLongformerForMaskedLM,
TFLongformerForQuestionAnswering,
TFLongformerSelfAttention,
)
from .modeling_tf_mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertModel,
This diff is collapsed.
@@ -29,6 +29,7 @@ from .configuration_auto import (
ElectraConfig,
FlaubertConfig,
GPT2Config,
LongformerConfig,
MobileBertConfig,
OpenAIGPTConfig,
RobertaConfig,
@@ -93,6 +94,7 @@ from .modeling_tf_flaubert import (
TFFlaubertWithLMHeadModel,
)
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
from .modeling_tf_mobilebert import (
TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice,
@@ -149,6 +151,7 @@ TF_MODEL_MAPPING = OrderedDict(
(AlbertConfig, TFAlbertModel),
(CamembertConfig, TFCamembertModel),
(XLMRobertaConfig, TFXLMRobertaModel),
(LongformerConfig, TFLongformerModel),
(RobertaConfig, TFRobertaModel),
(BertConfig, TFBertModel),
(OpenAIGPTConfig, TFOpenAIGPTModel),
@@ -191,6 +194,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(AlbertConfig, TFAlbertForMaskedLM),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
@@ -226,6 +230,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
(AlbertConfig, TFAlbertForMaskedLM),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(MobileBertConfig, TFMobileBertForMaskedLM),
@@ -259,6 +264,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
(AlbertConfig, TFAlbertForQuestionAnswering),
(CamembertConfig, TFCamembertForQuestionAnswering),
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
(LongformerConfig, TFLongformerForQuestionAnswering),
(RobertaConfig, TFRobertaForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(XLNetConfig, TFXLNetForQuestionAnsweringSimple),
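
With these mappings registered, the TF auto classes can resolve Longformer checkpoints to the new TF model classes. A minimal sketch (the checkpoint name is illustrative and assumes TF weights are available):

from transformers import TFAutoModelForQuestionAnswering

# the checkpoint's LongformerConfig now routes to TFLongformerForQuestionAnswering
model = TFAutoModelForQuestionAnswering.from_pretrained("allenai/longformer-base-4096")
assert model.__class__.__name__ == "TFLongformerForQuestionAnswering"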
This diff is collapsed.
@@ -33,6 +33,7 @@ if is_torch_available():
LongformerForTokenClassification,
LongformerForQuestionAnswering,
LongformerForMultipleChoice,
LongformerSelfAttention,
)
@@ -325,7 +326,209 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
@require_torch
class LongformerModelIntegrationTest(unittest.TestCase):
def _get_hidden_states(self):
return torch.tensor(
[
[
[
4.98332758e-01,
2.69175139e00,
-7.08081422e-03,
1.04915401e00,
-1.83476661e00,
7.67220476e-01,
2.98580543e-01,
2.84803992e-02,
],
[
-7.58357372e-01,
4.20635998e-01,
-4.04739919e-02,
1.59924145e-01,
2.05135748e00,
-1.15997978e00,
5.37166397e-01,
2.62873606e-01,
],
[
-1.69438001e00,
4.17574660e-01,
-1.49196962e00,
-1.76483717e00,
-1.94566312e-01,
-1.71183858e00,
7.72903565e-01,
-1.11557056e00,
],
[
5.44028163e-01,
2.05466114e-01,
-3.63045868e-01,
2.41865062e-01,
3.20348382e-01,
-9.05611176e-01,
-1.92690727e-01,
-1.19917547e00,
],
]
],
dtype=torch.float32,
device=torch_device,
)
def test_diagonalize(self):
hidden_states = self._get_hidden_states()
hidden_states = hidden_states.reshape((1, 8, 4)) # set seq length = 8, hidden dim = 4
chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2)
window_overlap_size = chunked_hidden_states.shape[2]
self.assertTrue(window_overlap_size == 4)
padded_hidden_states = LongformerSelfAttention._pad_and_diagonalize(chunked_hidden_states)
self.assertTrue(padded_hidden_states.shape[-1] == chunked_hidden_states.shape[-1] + window_overlap_size - 1)
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
self.assertTrue(torch.allclose(padded_hidden_states[0, 0, 0, :4], chunked_hidden_states[0, 0, 0], atol=1e-3))
self.assertTrue(
torch.allclose(
padded_hidden_states[0, 0, 0, 4:],
torch.zeros((3,), device=torch_device, dtype=torch.float32),
atol=1e-3,
)
)
# last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
self.assertTrue(torch.allclose(padded_hidden_states[0, 0, -1, 3:], chunked_hidden_states[0, 0, -1], atol=1e-3))
self.assertTrue(
torch.allclose(
padded_hidden_states[0, 0, -1, :3],
torch.zeros((3,), device=torch_device, dtype=torch.float32),
atol=1e-3,
)
)
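# A naive standalone reference for the shift these assertions encode (the
# helper name is illustrative; the model's _pad_and_diagonalize achieves the
# same layout with a vectorized pad-and-reshape trick): row i of each chunk
# is shifted right by i positions and zero-padded, so rows line up along
# diagonals of width hidden_dim + num_rows - 1.
import torch

def pad_and_diagonalize_reference(chunked: torch.Tensor) -> torch.Tensor:
    *lead, rows, hidden = chunked.shape
    out = chunked.new_zeros(*lead, rows, hidden + rows - 1)
    for i in range(rows):
        # row i occupies columns [i, i + hidden); everything else stays zero
        out[..., i, i : i + hidden] = chunked[..., i, :]
    return out

# e.g. a (1, 3, 4, 4) chunked tensor becomes (1, 3, 4, 7), matching the
# padded width asserted above.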
def test_pad_and_transpose_last_two_dims(self):
hidden_states = self._get_hidden_states()
self.assertEqual(hidden_states.shape, (1, 4, 8))  # _get_hidden_states returns (batch, seq_len=4, hidden=8)
padding = (0, 0, 0, 1)
padded_hidden_states = LongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, padding)
self.assertEqual(padded_hidden_states.shape, (1, 8, 5))
expected_added_dim = torch.zeros((5,), device=torch_device, dtype=torch.float32)
self.assertTrue(torch.allclose(expected_added_dim, padded_hidden_states[0, -1, :], atol=1e-6))
self.assertTrue(torch.allclose(hidden_states[0, -1, :], padded_hidden_states.view(1, -1)[0, 24:32], atol=1e-6))
def test_chunk(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = hidden_states.reshape((batch_size, seq_length, hidden_size))
chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2)
# expected slices across chunk and seq length dim
expected_slice_along_seq_length = torch.tensor(
[0.4983, -0.7584, -1.6944], device=torch_device, dtype=torch.float32
)
expected_slice_along_chunk = torch.tensor(
[0.4983, -1.8348, -0.7584, 2.0514], device=torch_device, dtype=torch.float32
)
self.assertTrue(torch.allclose(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, atol=1e-3))
self.assertTrue(torch.allclose(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, atol=1e-3))
self.assertEqual(chunked_hidden_states.shape, (1, 3, 4, 4))
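# The overlapping chunking exercised above can be reproduced with a small
# standalone sketch (illustrative only; LongformerSelfAttention._chunk is the
# model's actual implementation): a sequence of length 8 split into windows
# of size 2 * window_overlap = 4 with stride window_overlap = 2 yields the
# 3 chunks checked above, each sharing 2 rows with its neighbor.
import torch

def chunk_with_overlap(hidden_states: torch.Tensor, window_overlap: int) -> torch.Tensor:
    # unfold over the sequence dimension: window 2 * overlap, stride overlap
    chunks = hidden_states.unfold(1, 2 * window_overlap, window_overlap)
    # unfold leaves the window dimension last: (batch, n_chunks, hidden, window)
    return chunks.transpose(2, 3)  # -> (batch, n_chunks, window, hidden)

x = torch.arange(32, dtype=torch.float32).reshape(1, 8, 4)
assert chunk_with_overlap(x, window_overlap=2).shape == (1, 3, 4, 4)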
def test_mask_invalid_locations(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = hidden_states.reshape((batch_size, seq_length, hidden_size))
chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2)
hid_states_1 = chunked_hidden_states.clone()
LongformerSelfAttention._mask_invalid_locations(hid_states_1, 1)
self.assertTrue(torch.isinf(hid_states_1).sum().item() == 8)
hid_states_2 = chunked_hidden_states.clone()
LongformerSelfAttention._mask_invalid_locations(hid_states_2, 2)
self.assertTrue(torch.isinf(hid_states_2).sum().item() == 24)
hid_states_3 = chunked_hidden_states.clone()[:, :, :, :3]
LongformerSelfAttention._mask_invalid_locations(hid_states_3, 2)
self.assertTrue(torch.isinf(hid_states_3).sum().item() == 24)
hid_states_4 = chunked_hidden_states.clone()[:, :, 2:, :]
LongformerSelfAttention._mask_invalid_locations(hid_states_4, 2)
self.assertTrue(torch.isinf(hid_states_4).sum().item() == 12)
def test_layer_local_attn(self):
model = LongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
model.eval()
layer = model.encoder.layer[0].attention.self.to(torch_device)
hidden_states = self._get_hidden_states()
batch_size, seq_length, hidden_size = hidden_states.size()
attention_mask = torch.zeros((batch_size, 1, 1, seq_length), dtype=torch.float32, device=torch_device)
attention_mask[:, :, :, -2:] = -10000
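# -10000 acts as the standard additive padding mask, so the last two
# positions are excluded from attention; no token gets global attention here.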
output_hidden_states = layer(hidden_states, attention_mask)[0]
self.assertEqual(output_hidden_states.shape, (1, 4, 8))
self.assertTrue(
torch.allclose(
output_hidden_states[0, 1],
torch.tensor(
[0.0019, 0.0122, -0.0171, -0.0256, -0.0300, 0.0173, -0.0115, 0.0048],
dtype=torch.float32,
device=torch_device,
),
atol=1e-3,
)
)
def test_layer_global_attn(self):
model = LongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
model.eval()
layer = model.encoder.layer[0].attention.self.to(torch_device)
hidden_states = torch.cat([self._get_hidden_states(), self._get_hidden_states() - 0.5], dim=0)
batch_size, seq_length, hidden_size = hidden_states.size()
attention_mask = torch.zeros((batch_size, 1, 1, seq_length), dtype=torch.float32, device=torch_device)
# create attn mask
attention_mask[0, :, :, -2:] = 10000.0
attention_mask[0, :, :, -1:] = -10000.0
attention_mask[1, :, :, 1:] = 10000.0
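# Convention: a large positive value (+10000.0) flags a token for global
# attention, a large negative value (-10000.0) masks it out, and zero keeps
# ordinary sliding-window local attention. Batch 0 thus has one global token
# (the second-to-last; the last is masked), while batch 1 marks every token
# but the first as global.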
output_hidden_states = layer(hidden_states, attention_mask)[0]
self.assertEqual(output_hidden_states.shape, (2, 4, 8))
self.assertTrue(
torch.allclose(
output_hidden_states[0, 2],
torch.tensor(
[-0.0651, -0.0393, 0.0309, -0.0342, -0.0066, -0.0155, -0.0209, -0.0494],
dtype=torch.float32,
device=torch_device,
),
atol=1e-3,
)
)
self.assertTrue(
torch.allclose(
output_hidden_states[1, -2],
torch.tensor(
[-0.0405, -0.0384, 0.0396, -0.0374, -0.0341, 0.0136, 0.0014, -0.0571],
dtype=torch.float32,
device=torch_device,
),
atol=1e-3,
)
)
@slow
def test_inference_no_head(self):
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
@@ -371,13 +574,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
input_ids = torch.tensor(
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
) # long input
input_ids = input_ids.to(torch_device)
loss, prediction_scores = model(input_ids, labels=input_ids)
expected_loss = torch.tensor(0.0074, device=torch_device)
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
expected_prediction_scores_mean = torch.tensor(-3.0348, device=torch_device)
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
self.assertTrue(torch.allclose(prediction_scores.sum(), expected_prediction_scores_sum, atol=1e-4))
This diff is collapsed.