Unverified Commit c589eae2 authored by Patrick von Platen, committed by GitHub

[Longformer For Question Answering] Conversion script, doc, small fixes (#4593)

* add new longformer for question answering model

* add new config as well

* fix links

* fix links part 2
parent a163c9ca
@@ -67,3 +67,10 @@ LongformerForMaskedLM
.. autoclass:: transformers.LongformerForMaskedLM
:members:
LongformerForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LongformerForQuestionAnswering
:members:
@@ -25,6 +25,7 @@ logger = logging.getLogger(__name__)
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",
"longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/config.json",
"longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/config.json",
}
...
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RoBERTa checkpoint."""
import argparse
import pytorch_lightning as pl
import torch
from transformers.modeling_longformer import LongformerForQuestionAnswering, LongformerModel
class LightningModel(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.num_labels = 2
        self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size, self.num_labels)

    # implemented only because Lightning requires it
    def forward(self):
        pass
def convert_longformer_qa_checkpoint_to_pytorch(
    longformer_model: str, longformer_question_answering_ckpt_path: str, pytorch_dump_folder_path: str
):
    # load longformer model from model identifier
    longformer = LongformerModel.from_pretrained(longformer_model)
    lightning_model = LightningModel(longformer)

    ckpt = torch.load(longformer_question_answering_ckpt_path, map_location=torch.device("cpu"))
    lightning_model.load_state_dict(ckpt["state_dict"])

    # init longformer question answering model
    longformer_for_qa = LongformerForQuestionAnswering.from_pretrained(longformer_model)

    # transfer weights
    longformer_for_qa.longformer.load_state_dict(lightning_model.model.state_dict())
    longformer_for_qa.qa_outputs.load_state_dict(lightning_model.qa_outputs.state_dict())
    longformer_for_qa.eval()

    # save model
    longformer_for_qa.save_pretrained(pytorch_dump_folder_path)
    print("Conversion successful. Model saved under {}".format(pytorch_dump_folder_path))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--longformer_model",
        default=None,
        type=str,
        required=True,
        help="model identifier of longformer. Should be either `longformer-base-4096` or `longformer-large-4096`.",
    )
    parser.add_argument(
        "--longformer_question_answering_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the official PyTorch Lightning checkpoint.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_longformer_qa_checkpoint_to_pytorch(
        args.longformer_model, args.longformer_question_answering_ckpt_path, args.pytorch_dump_folder_path
    )
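For completeness, the conversion can also be driven from Python instead of the command line. The following is a minimal sketch: only convert_longformer_qa_checkpoint_to_pytorch and LongformerForQuestionAnswering come from the code above, and the two local paths are hypothetical placeholders.

# Hypothetical usage sketch for the conversion function defined above.
convert_longformer_qa_checkpoint_to_pytorch(
    longformer_model="longformer-large-4096",
    longformer_question_answering_ckpt_path="./triviaqa_checkpoint.ckpt",  # placeholder path
    pytorch_dump_folder_path="./longformer_qa",  # placeholder output folder
)

# The exported folder can then be reloaded with the standard Transformers API:
from transformers import LongformerForQuestionAnswering
model = LongformerForQuestionAnswering.from_pretrained("./longformer_qa")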
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
"longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/pytorch_model.bin",
"longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/pytorch_model.bin",
"longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/pytorch_model.bin",
}
@@ -710,7 +711,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
@add_start_docstrings(
"""Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA (linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
LONGFORMER_START_DOCSTRING,
)
@@ -728,26 +729,27 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
self.init_weights()
def _get_question_end_index(self, input_ids):
sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()
assert sep_token_indices.size(1) == 2, "input_ids should have two dimensions"
assert sep_token_indices.size(0) == 3 * input_ids.size(
0
), "There should be exactly three separator tokens in every sample for questions answering"
return sep_token_indices.view(input_ids.size(0), 3, 2)[:, 0, 1]
def _compute_global_attention_mask(self, input_ids):
question_end_index = self._get_question_end_index(input_ids)
question_end_index = question_end_index.unsqueeze(dim=1)  # size: batch_size x 1
# bool attention mask with True in locations of global attention
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
attention_mask = attention_mask.expand_as(input_ids) < question_end_index
attention_mask = attention_mask.int() + 1  # True => global attention; False => local attention
return attention_mask.long()
def _get_question_end_index(self, input_ids):
sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()
batch_size = input_ids.shape[0]
assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
assert (
sep_token_indices.shape[0] == 3 * batch_size
), f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering"
return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
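To make the two helpers above concrete, here is a small standalone sketch of the same arithmetic on a toy batch. The token ids and the sep_token_id value are invented for illustration; in the model they come from the tokenizer and self.config.sep_token_id.

import torch

# toy batch: <s> question </s> </s> context </s>, with made-up ids and sep_token_id = 2
input_ids = torch.tensor([[0, 10, 11, 2, 2, 20, 21, 22, 2]])
sep_token_id = 2

# _get_question_end_index: position of the first separator token in each sample
sep_token_indices = (input_ids == sep_token_id).nonzero()
question_end_index = sep_token_indices.view(input_ids.shape[0], 3, 2)[:, 0, 1].unsqueeze(dim=1)

# _compute_global_attention_mask: 2 (global attention) for question tokens, 1 (local) for the rest
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
attention_mask = attention_mask.expand_as(input_ids) < question_end_index
attention_mask = attention_mask.int() + 1
print(attention_mask)  # tensor([[2, 2, 2, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)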
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
def forward(
self,
@@ -769,7 +771,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
@@ -785,24 +787,29 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::

from transformers import LongformerTokenizer, LongformerForQuestionAnswering
import torch

tokenizer = LongformerTokenizer.from_pretrained("longformer-large-4096-finetuned-triviaqa")
model = LongformerForQuestionAnswering.from_pretrained("longformer-large-4096-finetuned-triviaqa")

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
encoding = tokenizer.encode_plus(question, text, return_tensors="pt")
input_ids = encoding["input_ids"]

# default is local attention everywhere
# the forward method will automatically set global attention on question tokens
attention_mask = encoding["attention_mask"]

start_scores, end_scores = model(input_ids, attention_mask=attention_mask)
all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

answer_tokens = all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1]
answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens))  # decoding removes the space character that prefixes each answer token
""" """
# set global attention on question tokens # set global attention on question tokens
...
@@ -24,12 +24,13 @@ logger = logging.getLogger(__name__)
# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_longformer_models = ["longformer-base-4096", "longformer-large-4096", "longformer-large-4096-finetuned-triviaqa"]
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"longformer-base-4096": 4096,
"longformer-large-4096": 4096,
"longformer-large-4096-finetuned-triviaqa": 4096,
}
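Because the new checkpoint reuses the RoBERTa vocab and merges registered above, its tokenizer loads like the other Longformer tokenizers, with a maximum length of 4096 via PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES. A minimal sketch mirroring the question-answering docstring example above; the question/context strings are arbitrary.

from transformers import LongformerTokenizer

# load the tokenizer for the newly registered fine-tuned checkpoint
tokenizer = LongformerTokenizer.from_pretrained("longformer-large-4096-finetuned-triviaqa")

# encode a question/context pair exactly as in the question-answering example above
encoding = tokenizer.encode_plus("Who was Jim Henson?", "Jim Henson was a nice puppet", return_tensors="pt")
print(encoding["input_ids"].shape)  # (1, sequence_length), well below the 4096-token limit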
...