Unverified commit 03d8527d authored by Suraj Patil, committed by GitHub

Longformer for question answering (#4500)

* added LongformerForQuestionAnswering

* add LongformerForQuestionAnswering

* fix import for LongformerForMaskedLM

* add LongformerForQuestionAnswering

* hardcoded sep_token_id

* compute attention_mask if not provided

* combine global_attention_mask with attention_mask when provided

* update example in docstring

* add assert error messages, better attention combine

* add test for longformerForQuestionAnswering

* typo

* cast global_attention_mask to long

* make style

* Update src/transformers/configuration_longformer.py

* Update src/transformers/configuration_longformer.py

* fix the code quality

* Merge branch 'longformer-for-question-answering' of https://github.com/patil-suraj/transformers into longformer-for-question-answering

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent a34a9896
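For readers skimming the diff below: the attention-mask bullets above boil down to an elementwise multiplication of two masks. A minimal, illustrative sketch (not code from this commit; the toy tensors are made up, and the 0/1/2 convention follows the Longformer docs, where 0 = no attention on padding, 1 = local attention, 2 = global attention):

    import torch

    # Hypothetical toy example: batch of 1, sequence length 8.
    # attention_mask: 1 for real tokens, 0 for padding (standard HF convention).
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
    # global_attention_mask: 2 on question tokens, 1 elsewhere
    # (here the "question" is assumed to span the first three positions).
    global_attention_mask = torch.tensor([[2, 2, 2, 1, 1, 1, 1, 1]])

    # Elementwise product keeps padding at 0, marks question tokens with 2,
    # and leaves the remaining real tokens at 1.
    combined = global_attention_mask * attention_mask
    print(combined)  # tensor([[2, 2, 2, 1, 1, 1, 0, 0]])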
src/transformers/__init__.py

@@ -335,7 +335,12 @@ if is_torch_available():
         REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
     )

-    from .modeling_longformer import LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, LongformerModel, LongformerForMaskedLM
+    from .modeling_longformer import (
+        LongformerModel,
+        LongformerForMaskedLM,
+        LongformerForQuestionAnswering,
+        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
+    )

     # Optimization
     from .optimization import (
...
src/transformers/configuration_longformer.py

@@ -64,6 +64,7 @@ class LongformerConfig(RobertaConfig):
     pretrained_config_archive_map = LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "longformer"

-    def __init__(self, attention_window: Union[List[int], int] = 512, **kwargs):
+    def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs):
         super().__init__(**kwargs)
         self.attention_window = attention_window
+        self.sep_token_id = sep_token_id
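The new sep_token_id attribute can be set like any other config keyword. A small sketch (the values mirror the defaults introduced above; 2 is RoBERTa's </s> id):

    from transformers import LongformerConfig

    # sep_token_id defaults to 2 (RoBERTa's </s> token id); override it only if
    # a checkpoint's tokenizer uses a different separator id.
    config = LongformerConfig(attention_window=512, sep_token_id=2)
    print(config.sep_token_id)  # 2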
src/transformers/modeling_auto.py

@@ -101,7 +101,12 @@ from .modeling_flaubert import (
     FlaubertWithLMHeadModel,
 )
 from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
-from .modeling_longformer import LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, LongformerForMaskedLM, LongformerModel
+from .modeling_longformer import (
+    LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
+    LongformerForMaskedLM,
+    LongformerForQuestionAnswering,
+    LongformerModel,
+)
 from .modeling_marian import MarianMTModel
 from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
 from .modeling_reformer import ReformerModel, ReformerModelWithLMHead

@@ -260,6 +265,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
     [
         (DistilBertConfig, DistilBertForQuestionAnswering),
         (AlbertConfig, AlbertForQuestionAnswering),
+        (LongformerConfig, LongformerForQuestionAnswering),
         (RobertaConfig, RobertaForQuestionAnswering),
         (BertConfig, BertForQuestionAnswering),
         (XLNetConfig, XLNetForQuestionAnsweringSimple),
...
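With LongformerConfig registered in MODEL_FOR_QUESTION_ANSWERING_MAPPING, the auto class should dispatch to the new head. A sketch assuming a deliberately tiny, made-up config (sizes are illustrative, not a real checkpoint, and from_config is assumed to be available on the installed version; real use would call AutoModelForQuestionAnswering.from_pretrained on a Longformer checkpoint):

    from transformers import AutoModelForQuestionAnswering, LongformerConfig

    # Tiny illustrative config so instantiation is cheap.
    config = LongformerConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        attention_window=8,
        max_position_embeddings=64,
    )
    model = AutoModelForQuestionAnswering.from_config(config)
    print(type(model).__name__)  # LongformerForQuestionAnswering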
src/transformers/modeling_longformer.py

@@ -707,3 +707,144 @@ class LongformerForMaskedLM(BertPreTrainedModel):
             outputs = (masked_lm_loss,) + outputs

         return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
+    the hidden-states output to compute `span start logits` and `span end logits`). """,
+    LONGFORMER_START_DOCSTRING,
+)
+class LongformerForQuestionAnswering(BertPreTrainedModel):
+    config_class = LongformerConfig
+    pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "longformer"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.longformer = LongformerModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    def _get_question_end_index(self, input_ids):
+        sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()
+        assert sep_token_indices.size(1) == 2, "input_ids should have two dimensions"
+        assert sep_token_indices.size(0) == 3 * input_ids.size(
+            0
+        ), "There should be exactly three separator tokens in every sample for question answering"
+        return sep_token_indices.view(input_ids.size(0), 3, 2)[:, 0, 1]
+
+    def _compute_global_attention_mask(self, input_ids):
+        question_end_index = self._get_question_end_index(input_ids)
+        question_end_index = question_end_index.unsqueeze(dim=1)  # size: batch_size x 1
+        # bool attention mask with True in locations of global attention
+        attention_mask = torch.arange(input_ids.size(1), device=input_ids.device)
+        attention_mask = attention_mask.expand_as(input_ids) < question_end_index
+        attention_mask = attention_mask.int() + 1  # from True, False to 2, 1
+        return attention_mask.long()
+
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+    ):
+        r"""
+        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-start scores (before SoftMax).
+        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-end scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import LongformerTokenizer, LongformerForQuestionAnswering
+        import torch
+
+        tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
+        model = LongformerForQuestionAnswering.from_pretrained('allenai/longformer-base-4096')
+
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        encoding = tokenizer.encode_plus(question, text)
+        input_ids = encoding["input_ids"]
+
+        # default is local attention everywhere
+        # the forward method will automatically set global attention on question tokens
+        attention_mask = encoding["attention_mask"]
+
+        start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))
+        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+
+        answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1])
+        """
+
+        # set global attention on question tokens
+        global_attention_mask = self._compute_global_attention_mask(input_ids)
+        if attention_mask is None:
+            attention_mask = global_attention_mask
+        else:
+            # combine global_attention_mask with attention_mask
+            # global attention on question tokens, no attention on padding tokens
+            attention_mask = global_attention_mask * attention_mask
+
+        outputs = self.longformer(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        outputs = (start_logits, end_logits,) + outputs[2:]
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions.clamp_(0, ignored_index)
+            end_positions.clamp_(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
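To make the global-attention computation above concrete, here is a standalone sketch that mirrors (but is not) `_compute_global_attention_mask` on a toy batch; the ids and the sep_token_id of 2 are made up for illustration:

    import torch

    sep_token_id = 2  # matches the new LongformerConfig default

    # Toy RoBERTa-style encoding: <s> question </s></s> context </s>
    # -> exactly three separator tokens per row, as the asserts above require.
    input_ids = torch.tensor([[0, 10, 11, 2, 2, 12, 13, 2]])

    # Column index of the first separator = end of the question segment.
    sep_indices = (input_ids == sep_token_id).nonzero()
    question_end_index = sep_indices.view(input_ids.size(0), 3, 2)[:, 0, 1].unsqueeze(1)

    # Positions before the first separator get global attention (2),
    # everything else keeps local attention (1).
    positions = torch.arange(input_ids.size(1)).expand_as(input_ids)
    mask = (positions < question_end_index).int() + 1
    print(mask.long())  # tensor([[2, 2, 2, 1, 1, 1, 1, 1]])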
tests/test_modeling_longformer.py

@@ -29,6 +29,7 @@ if is_torch_available():
         LongformerConfig,
         LongformerModel,
         LongformerForMaskedLM,
+        LongformerForQuestionAnswering,
     )
@@ -171,6 +172,28 @@ class LongformerModelTester(object):
         )
         self.check_loss_output(result)

+    def create_and_check_longformer_for_question_answering(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        model = LongformerForQuestionAnswering(config=config)
+        model.to(torch_device)
+        model.eval()
+        loss, start_logits, end_logits = model(
+            input_ids,
+            attention_mask=input_mask,
+            token_type_ids=token_type_ids,
+            start_positions=sequence_labels,
+            end_positions=sequence_labels,
+        )
+        result = {
+            "loss": loss,
+            "start_logits": start_logits,
+            "end_logits": end_logits,
+        }
+        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
+        self.check_loss_output(result)
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -185,6 +208,26 @@ class LongformerModelTester(object):
         inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
         return config, inputs_dict

+    def prepare_config_and_inputs_for_question_answering(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = config_and_inputs
+
+        # Replace sep_token_id by some random id
+        input_ids[input_ids == config.sep_token_id] = torch.randint(0, config.vocab_size, (1,)).item()
+        # Make sure there are exactly three sep_token_id
+        input_ids[:, -3:] = config.sep_token_id
+        input_mask = torch.ones_like(input_ids)
+
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+

 @require_torch
 class LongformerModelTest(ModelTesterMixin, unittest.TestCase):

@@ -209,6 +252,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)

+    def test_longformer_for_question_answering(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
+        self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
+

 class LongformerModelIntegrationTest(unittest.TestCase):
     @slow
...
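To exercise only the new test, a minimal runner sketch (the module path tests/test_modeling_longformer.py and the -k pattern are assumptions based on the repository layout at the time of this commit):

    import pytest

    # Select only the question-answering test added in this commit.
    pytest.main(["tests/test_modeling_longformer.py", "-k", "question_answering", "-q"])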