Unverified Commit 56ee2560 authored by Patrick von Platen, committed by GitHub

[Longformer] Better handling of global attention mask vs local attention mask (#4672)

* better api

* improve automatic setting of global attention mask

* fix longformer bug

* fix global attention mask in test

* fix global attn mask flatten

* fix slow tests

* update docstring

* update docs and make more robust

* improve attention mask
parent e2230ba7
......@@ -21,7 +21,7 @@ A selected few tokens attend "globally" to all other tokens, as it is convention
Note that "locally" and "globally" attending tokens are projected by different query, key and value matrices.
Also note that every "locally" attending token not only attends to tokens within its window :math:`w`, but also to all "globally" attending tokens so that global attention is *symmetric*.
The user can define which tokens are masked, which tokens attend "locally" and which tokens attend "globally" by setting the `config.attention_mask` `torch.Tensor` appropriately. In contrast to other models, `Longformer` accepts the following values in `config.attention_mask`: `0` - the token is masked and not attended at all (as is done in other models), `1` - the token attends "locally", `2` - the token attends "globally". For more information please also refer to the :func:`~transformers.LongformerModel.forward` method.
The user can define which tokens attend "locally" and which tokens attend "globally" by setting the tensor `global_attention_mask` appropriately at run-time. `Longformer` employs the following logic for `global_attention_mask`: `0` - the token attends "locally", `1` - the token attends "globally". For more information please also refer to the :func:`~transformers.LongformerModel.forward` method.
Using Longformer self attention, the memory and time complexity of the query-key matmul operation, which usually represents the memory and time bottleneck, can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times w)`, with :math:`n_s` being the sequence length and :math:`w` being the average window size. It is assumed that the number of "globally" attending tokens is insignificant as compared to the number of "locally" attending tokens.
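As a rough sketch of the run-time usage described above (the `allenai/longformer-base-4096` checkpoint is used purely as an example; any Longformer checkpoint works the same way)::

    import torch
    from transformers import LongformerModel, LongformerTokenizer

    tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
    model = LongformerModel.from_pretrained('allenai/longformer-base-4096')

    input_ids = tokenizer.encode('Hello, my dog is cute', return_tensors='pt')
    attention_mask = torch.ones_like(input_ids)           # 1 -> attend, 0 -> padding

    global_attention_mask = torch.zeros_like(input_ids)   # 0 -> local attention everywhere ...
    global_attention_mask[:, 0] = 1                        # ... 1 -> global attention on the first (<s>) token

    outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)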
......
......@@ -39,6 +39,44 @@ LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
}
def _get_question_end_index(input_ids, sep_token_id):
"""
Computes the index of the first occurrence of `sep_token_id` for each sample in the batch.
"""
sep_token_indices = (input_ids == sep_token_id).nonzero()
batch_size = input_ids.shape[0]
assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
assert (
sep_token_indices.shape[0] == 3 * batch_size
), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
# take the column index of the first separator token in each sample, i.e. the end of the question
return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True):
"""
Computes a global attention mask that puts global attention on all tokens
before `sep_token_id` if `before_sep_token` is True, and on all tokens
after `sep_token_id` otherwise.
"""
question_end_index = _get_question_end_index(input_ids, sep_token_id)
question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1
# bool attention mask with True in locations of global attention
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
if before_sep_token is True:
attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8)
else:
# the last token is a separator token and should not be counted, and there are two separator tokens in the middle
attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * (
attention_mask.expand_as(input_ids) < input_ids.shape[-1]
).to(torch.uint8)
return attention_mask
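As a hedged illustration of the two helpers above (the token ids are made up; only the positions of `sep_token_id` matter, mimicking the `<s> question </s></s> context </s>` layout with three separator tokens per sample):

import torch

sep_token_id = 2
input_ids = torch.tensor([
    [0, 10, 11, 2, 2, 20, 21, 22, 2],
    [0, 30, 31, 2, 2, 40, 41, 42, 2],
])

_get_question_end_index(input_ids, sep_token_id)
# tensor([3, 3])  -> position of the first separator token in each sample

_compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True)
# tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0, 0, 0]], dtype=torch.uint8)

_compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=False)
# tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 0, 0, 1, 1, 1, 1]], dtype=torch.uint8)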
class LongformerSelfAttention(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
......@@ -420,17 +458,22 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to decide the attention given on each token: local attention, global attention, or no attention (for padding tokens).
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
global_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to decide the attention given on each token: local attention or global attention.
Tokens with global attention attend to all other tokens, and all other tokens attend to them. This is important for
task-specific finetuning because it makes the model more flexible at representing the task. For example,
for classification, the <s> token should be given global attention. For QA, all question tokens should also have
global attention. Please refer to the Longformer paper https://arxiv.org/abs/2004.05150 for more details.
Mask values selected in ``[0, 1, 2]``:
``0`` for no attention (padding tokens),
``1`` for local attention (a sliding window attention),
``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
Mask values selected in ``[0, 1]``:
``0`` for local attention (a sliding window attention),
``1`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
......@@ -542,6 +585,7 @@ class LongformerModel(RobertaModel):
self,
input_ids=None,
attention_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
......@@ -593,6 +637,19 @@ class LongformerModel(RobertaModel):
if isinstance(self.config.attention_window, int)
else max(self.config.attention_window)
)
# merge `global_attention_mask` and `attention_mask`
if global_attention_mask is not None:
# longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
# (global_attention_mask + 1) => 1 for local attention, 2 for global attention
# => final attention_mask => 0 for no attention, 1 for local attention, 2 for global attention
if attention_mask is not None:
attention_mask = attention_mask * (global_attention_mask + 1)
else:
# simply use `global_attention_mask` as `attention_mask`
# if no `attention_mask` is given
attention_mask = global_attention_mask + 1
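# sanity check of the merge on a toy example (illustrative values only, not part of the model logic):
#   attention_mask        = [1, 1, 1, 1, 0, 0]   # 0 marks padding
#   global_attention_mask = [1, 0, 0, 0, 0, 0]   # 1 marks global attention
#   attention_mask * (global_attention_mask + 1)
#                         = [2, 1, 1, 1, 0, 0]   # 0: no attention, 1: local, 2: global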
padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
input_ids=input_ids,
attention_mask=attention_mask,
......@@ -646,6 +703,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
self,
input_ids=None,
attention_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
......@@ -695,6 +753,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
outputs = self.longformer(
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
......@@ -734,6 +793,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
self,
input_ids=None,
attention_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
......@@ -778,15 +838,16 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
# global attention on cls token
attention_mask[:, 0] = 2
if global_attention_mask is None:
logger.info("Initializing global attention on CLS token...")
global_attention_mask = torch.zeros_like(input_ids)
# global attention on cls token
global_attention_mask[:, 0] = 1
outputs = self.longformer(
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
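To make the new default concrete, a minimal sketch (checkpoint name as used elsewhere in this PR; the classification head of the base checkpoint is randomly initialized, so this is for illustration only). The two calls below are equivalent:

import torch
from transformers import LongformerForSequenceClassification, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model = LongformerForSequenceClassification.from_pretrained('allenai/longformer-base-4096')

input_ids = tokenizer.encode('Hello, my dog is cute', return_tensors='pt')

# 1) no mask passed: global attention is put on the CLS token automatically
outputs = model(input_ids)

# 2) the explicit equivalent
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, 0] = 1
outputs = model(input_ids, global_attention_mask=global_attention_mask)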
......@@ -846,31 +907,12 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
self.init_weights()
def _compute_global_attention_mask(self, input_ids):
question_end_index = self._get_question_end_index(input_ids)
question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1
# bool attention mask with True in locations of global attention
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
attention_mask = attention_mask.expand_as(input_ids) < question_end_index
return attention_mask.long() + 1 # True => global attention; False => local attention
def _get_question_end_index(self, input_ids):
sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()
batch_size = input_ids.shape[0]
assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
assert (
sep_token_indices.shape[0] == 3 * batch_size
), f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering"
return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids,
attention_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
......@@ -929,17 +971,15 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
"""
# set global attention on question tokens
global_attention_mask = self._compute_global_attention_mask(input_ids)
if attention_mask is None:
attention_mask = global_attention_mask
else:
# combine global_attention_mask with attention_mask
# global attention on question tokens, no attention on padding tokens
attention_mask = global_attention_mask * attention_mask
if global_attention_mask is None:
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id)
outputs = self.longformer(
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
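A hedged sketch of the resulting QA default (base checkpoint for illustration; in practice one would use a checkpoint fine-tuned for question answering). Encoding a question/context pair yields `<s> question </s></s> context </s>`, so the question tokens sit before the first separator and receive global attention automatically:

from transformers import LongformerForQuestionAnswering, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model = LongformerForQuestionAnswering.from_pretrained('allenai/longformer-base-4096')

question, context = 'Who is cute?', 'My dog is cute.'
input_ids = tokenizer.encode(question, context, return_tensors='pt')

# no global_attention_mask passed: it is derived from `config.sep_token_id`
start_logits, end_logits = model(input_ids)[:2]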
......@@ -998,6 +1038,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
self,
input_ids=None,
attention_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
......@@ -1043,6 +1084,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
outputs = self.longformer(
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
......@@ -1097,6 +1139,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
input_ids=None,
token_type_ids=None,
attention_mask=None,
global_attention_mask=None,
labels=None,
position_ids=None,
inputs_embeds=None,
......@@ -1129,29 +1172,51 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
Examples::
from transformers import LongformerTokenizer, LongformerForTokenClassification
from transformers import LongformerTokenizer, LongformerForMultipleChoice
import torch
tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
model = LongformerForMultipleChoice.from_pretrained('longformer-base-4096')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model = LongformerForMultipleChoice.from_pretrained('allenai/longformer-base-4096')
# context = "The dog is cute" | choice = "the dog" / "the cat"
choices = [("The dog is cute", "the dog"), ("The dog is cute", "the cat")]
input_ids = torch.tensor([tokenizer.encode(s[0], s[1], add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
# global attention is automatically put on "the dog" and "the cat"
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
num_choices = input_ids.shape[1]
# set global attention on the choice tokens (the second segment of each pair)
if global_attention_mask is None:
logger.info("Initializing global attention on multiple choice...")
# put global attention on all tokens after `config.sep_token_id`
global_attention_mask = torch.stack(
[
_compute_global_attention_mask(input_ids[:, i], self.config.sep_token_id, before_sep_token=False)
for i in range(num_choices)
],
dim=1,
)
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
flat_global_attention_mask = (
global_attention_mask.view(-1, global_attention_mask.size(-1))
if global_attention_mask is not None
else None
)
outputs = self.longformer(
flat_input_ids,
position_ids=flat_position_ids,
token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask,
global_attention_mask=flat_global_attention_mask,
)
pooled_output = outputs[1]
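For orientation, a rough sketch of the flattening above (sizes are illustrative; the stacked per-choice masks share the `(batch_size, num_choices, seq_len)` shape of `input_ids`):

import torch

batch_size, num_choices, seq_len = 1, 2, 12           # illustrative sizes
input_ids = torch.randint(5, 100, (batch_size, num_choices, seq_len))
global_attention_mask = torch.zeros_like(input_ids)   # e.g. the stacked per-choice masks

flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_global_attention_mask = global_attention_mask.view(-1, global_attention_mask.size(-1))
assert flat_input_ids.shape == (batch_size * num_choices, seq_len)  # (2, 12), what LongformerModel receives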
......
......@@ -184,6 +184,7 @@ class LongformerModelTester(object):
loss, start_logits, end_logits = model(
input_ids,
attention_mask=input_mask,
global_attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
......@@ -239,9 +240,11 @@ class LongformerModelTester(object):
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
global_attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
......@@ -330,7 +333,7 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
class LongformerModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head(self):
model = LongformerModel.from_pretrained("longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.to(torch_device)
# 'Hello world! ' repeated 1000 times
......@@ -350,7 +353,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
model = LongformerForMaskedLM.from_pretrained("longformer-base-4096")
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
model.to(torch_device)
# 'Hello world! ' repeated 1000 times
......