Unverified Commit 8f1d0471 authored by Iz Beltagy's avatar Iz Beltagy Committed by GitHub
Browse files

Longformer (#4352)

* first commit

* bug fixes

* better examples

* undo padding

* remove wrong VOCAB_FILES_NAMES

* License

* make style

* make isort happy

* unit tests

* integration test

* make `black` happy by undoing `isort` changes!!

* lint

* no need for the padding value

* batch_size not bsz

* remove unused type casting

* seqlen not seq_len

* staticmethod

* `bert` selfattention instead of `n2`

* uint8 instead of bool + lints

* pad inputs_embeds using embeddings not a constant

* black

* unit test with padding

* fix unit tests

* remove redundant unit test

* upload model weights

* resolve todo

* simpler _mask_invalid_locations without lru_cache + backward compatible masked_fill_

* increase unittest coverage
parent 31eedff5
...@@ -165,8 +165,9 @@ At some point in the future, you'll be able to seamlessly move from pre-training ...@@ -165,8 +165,9 @@ At some point in the future, you'll be able to seamlessly move from pre-training
18. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. 18. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
19. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 19. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
20. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 20. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
21. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users). 21. **[Longformer](https://huggingface.co/transformers/model_doc/longformer.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
22. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR. 22. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
23. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Peason R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html). These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Peason R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
......
...@@ -305,3 +305,9 @@ For a list that includes community-uploaded models, refer to `https://huggingfac ...@@ -305,3 +305,9 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| MarianMT | ``Helsinki-NLP/opus-mt-{src}-{tgt}`` | | 12-layer, 512-hidden, 8-heads, ~74M parameter Machine translation models. Parameter counts vary depending on vocab size. | | MarianMT | ``Helsinki-NLP/opus-mt-{src}-{tgt}`` | | 12-layer, 512-hidden, 8-heads, ~74M parameter Machine translation models. Parameter counts vary depending on vocab size. |
| | | | (see `model list <https://huggingface.co/Helsinki-NLP>`_) | | | | | (see `model list <https://huggingface.co/Helsinki-NLP>`_) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Longformer | ``longformer-base-4096`` | | 12-layer, 768-hidden, 12-heads, ~149M parameters |
| | | | Starting from RoBERTa-base checkpoint, trained on documents of max length 4,096 |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``longformer-large-4096`` | | 24-layer, 1024-hidden, 16-heads, ~435M parameters |
| | | | Starting from RoBERTa-large checkpoint, trained on documents of max length 4,096 |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
...@@ -44,6 +44,7 @@ from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, Electr ...@@ -44,6 +44,7 @@ from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, Electr
from .configuration_encoder_decoder import EncoderDecoderConfig from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from .configuration_marian import MarianConfig from .configuration_marian import MarianConfig
from .configuration_mmbt import MMBTConfig from .configuration_mmbt import MMBTConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
...@@ -138,6 +139,7 @@ from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFas ...@@ -138,6 +139,7 @@ from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFas
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer from .tokenization_reformer import ReformerTokenizer
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
...@@ -332,6 +334,8 @@ if is_torch_available(): ...@@ -332,6 +334,8 @@ if is_torch_available():
REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
) )
from .modeling_longformer import LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, LongformerModel, LongformerForMaskedLM
# Optimization # Optimization
from .optimization import ( from .optimization import (
AdamW, AdamW,
......
...@@ -28,6 +28,7 @@ from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, Electr ...@@ -28,6 +28,7 @@ from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, Electr
from .configuration_encoder_decoder import EncoderDecoderConfig from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from .configuration_marian import MarianConfig from .configuration_marian import MarianConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import ReformerConfig from .configuration_reformer import ReformerConfig
...@@ -62,6 +63,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( ...@@ -62,6 +63,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
] ]
for key, value, in pretrained_map.items() for key, value, in pretrained_map.items()
) )
...@@ -77,6 +79,7 @@ CONFIG_MAPPING = OrderedDict( ...@@ -77,6 +79,7 @@ CONFIG_MAPPING = OrderedDict(
("marian", MarianConfig,), ("marian", MarianConfig,),
("bart", BartConfig,), ("bart", BartConfig,),
("reformer", ReformerConfig,), ("reformer", ReformerConfig,),
("longformer", LongformerConfig,),
("roberta", RobertaConfig,), ("roberta", RobertaConfig,),
("flaubert", FlaubertConfig,), ("flaubert", FlaubertConfig,),
("bert", BertConfig,), ("bert", BertConfig,),
...@@ -133,6 +136,7 @@ class AutoConfig: ...@@ -133,6 +136,7 @@ class AutoConfig:
- contains `albert`: :class:`~transformers.AlbertConfig` (ALBERT model) - contains `albert`: :class:`~transformers.AlbertConfig` (ALBERT model)
- contains `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model) - contains `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model) - contains `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model)
- contains `longformer`: :class:`~transformers.LongformerConfig` (Longformer model)
- contains `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model) - contains `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model)
- contains `reformer`: :class:`~transformers.ReformerConfig` (Reformer model) - contains `reformer`: :class:`~transformers.ReformerConfig` (Reformer model)
- contains `bert`: :class:`~transformers.BertConfig` (Bert model) - contains `bert`: :class:`~transformers.BertConfig` (Bert model)
...@@ -145,7 +149,6 @@ class AutoConfig: ...@@ -145,7 +149,6 @@ class AutoConfig:
- contains `flaubert` : :class:`~transformers.FlaubertConfig` (Flaubert model) - contains `flaubert` : :class:`~transformers.FlaubertConfig` (Flaubert model)
- contains `electra` : :class:`~transformers.ElectraConfig` (ELECTRA model) - contains `electra` : :class:`~transformers.ElectraConfig` (ELECTRA model)
Args: Args:
pretrained_model_name_or_path (:obj:`string`): pretrained_model_name_or_path (:obj:`string`):
Is either: \ Is either: \
......
# coding=utf-8
# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Longformer configuration """
import logging
from typing import List, Union
from .configuration_roberta import RobertaConfig
logger = logging.getLogger(__name__)
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",
"longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/config.json",
}
class LongformerConfig(RobertaConfig):
r"""
This is the configuration class to store the configuration of an :class:`~transformers.LongformerModel`.
It is used to instantiate an Longformer model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the RoBERTa `roberta-base <https://huggingface.co/roberta-base>`__ architecture with a sequence length 4,096.
The :class:`~transformers.LongformerConfig` class directly inherits :class:`~transformers.RobertaConfig`.
It reuses the same defaults. Please check the parent class for more information.
Example::
from transformers import LongformerConfig, LongformerModel
# Initializing a Longformer configuration
configuration = LongformerConfig()
# Initializing a model from the configuration
model = LongformerModel(configuration)
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "longformer"
def __init__(self, attention_window: Union[List[int], int] = 512, attention_mode: str = "longformer", **kwargs):
"""
Args:
attention_window (:obj:`int` or :obj:`List[int]`, optional, defaults to 512):
Size of an attention window around each token. If :obj:`int`, use the same size for all layers.
To specify a different window size for each layer, use a :obj:`List[int]` where
`len(attention_window) == num_hidden_layers`.
attention_mode (:obj:`str`, optional, possible values ['longformer', 'bert'], defaults to 'longformer'):
Type of selfattention. Use 'longformer' for :obj:`LongformerSelfAttention` or 'bert' for
standard BERT full n^2 self attention using :obj:`modeling_bert.BertSelfAttention`. Note that full n^2
selfattention is supported just for comparison, but it will OOM for long sequences.
"""
super().__init__(**kwargs)
self.attention_window = attention_window
self.attention_mode = attention_mode
...@@ -30,6 +30,7 @@ from .configuration_auto import ( ...@@ -30,6 +30,7 @@ from .configuration_auto import (
EncoderDecoderConfig, EncoderDecoderConfig,
FlaubertConfig, FlaubertConfig,
GPT2Config, GPT2Config,
LongformerConfig,
OpenAIGPTConfig, OpenAIGPTConfig,
ReformerConfig, ReformerConfig,
RobertaConfig, RobertaConfig,
...@@ -99,6 +100,7 @@ from .modeling_flaubert import ( ...@@ -99,6 +100,7 @@ from .modeling_flaubert import (
FlaubertWithLMHeadModel, FlaubertWithLMHeadModel,
) )
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
from .modeling_longformer import LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, LongformerForMaskedLM, LongformerModel
from .modeling_marian import MarianMTModel from .modeling_marian import MarianMTModel
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
...@@ -162,6 +164,7 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict( ...@@ -162,6 +164,7 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP, FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
] ]
for key, value, in pretrained_map.items() for key, value, in pretrained_map.items()
) )
...@@ -174,6 +177,7 @@ MODEL_MAPPING = OrderedDict( ...@@ -174,6 +177,7 @@ MODEL_MAPPING = OrderedDict(
(CamembertConfig, CamembertModel), (CamembertConfig, CamembertModel),
(XLMRobertaConfig, XLMRobertaModel), (XLMRobertaConfig, XLMRobertaModel),
(BartConfig, BartModel), (BartConfig, BartModel),
(LongformerConfig, LongformerModel),
(RobertaConfig, RobertaModel), (RobertaConfig, RobertaModel),
(BertConfig, BertModel), (BertConfig, BertModel),
(OpenAIGPTConfig, OpenAIGPTModel), (OpenAIGPTConfig, OpenAIGPTModel),
...@@ -196,6 +200,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( ...@@ -196,6 +200,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(CamembertConfig, CamembertForMaskedLM), (CamembertConfig, CamembertForMaskedLM),
(XLMRobertaConfig, XLMRobertaForMaskedLM), (XLMRobertaConfig, XLMRobertaForMaskedLM),
(BartConfig, BartForConditionalGeneration), (BartConfig, BartForConditionalGeneration),
(LongformerConfig, LongformerForMaskedLM),
(RobertaConfig, RobertaForMaskedLM), (RobertaConfig, RobertaForMaskedLM),
(BertConfig, BertForPreTraining), (BertConfig, BertForPreTraining),
(OpenAIGPTConfig, OpenAIGPTLMHeadModel), (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
...@@ -218,6 +223,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( ...@@ -218,6 +223,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(XLMRobertaConfig, XLMRobertaForMaskedLM), (XLMRobertaConfig, XLMRobertaForMaskedLM),
(MarianConfig, MarianMTModel), (MarianConfig, MarianMTModel),
(BartConfig, BartForConditionalGeneration), (BartConfig, BartForConditionalGeneration),
(LongformerConfig, LongformerForMaskedLM),
(RobertaConfig, RobertaForMaskedLM), (RobertaConfig, RobertaForMaskedLM),
(BertConfig, BertForMaskedLM), (BertConfig, BertForMaskedLM),
(OpenAIGPTConfig, OpenAIGPTLMHeadModel), (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
...@@ -313,6 +319,7 @@ class AutoModel: ...@@ -313,6 +319,7 @@ class AutoModel:
The model class to instantiate is selected based on the configuration class: The model class to instantiate is selected based on the configuration class:
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModel` (DistilBERT model) - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModel` (DistilBERT model)
- isInstance of `longformer` configuration class: :class:`~transformers.LongformerModel` (Longformer model)
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model) - isInstance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model)
- isInstance of `bert` configuration class: :class:`~transformers.BertModel` (Bert model) - isInstance of `bert` configuration class: :class:`~transformers.BertModel` (Bert model)
- isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model) - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
...@@ -355,6 +362,7 @@ class AutoModel: ...@@ -355,6 +362,7 @@ class AutoModel:
- contains `albert`: :class:`~transformers.AlbertModel` (ALBERT model) - contains `albert`: :class:`~transformers.AlbertModel` (ALBERT model)
- contains `camembert`: :class:`~transformers.CamembertModel` (CamemBERT model) - contains `camembert`: :class:`~transformers.CamembertModel` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaModel` (XLM-RoBERTa model) - contains `xlm-roberta`: :class:`~transformers.XLMRobertaModel` (XLM-RoBERTa model)
- contains `longformer` :class:`~transformers.LongformerModel` (Longformer model)
- contains `roberta`: :class:`~transformers.RobertaModel` (RoBERTa model) - contains `roberta`: :class:`~transformers.RobertaModel` (RoBERTa model)
- contains `bert`: :class:`~transformers.BertModel` (Bert model) - contains `bert`: :class:`~transformers.BertModel` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model) - contains `openai-gpt`: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
...@@ -463,6 +471,7 @@ class AutoModelForPreTraining: ...@@ -463,6 +471,7 @@ class AutoModelForPreTraining:
The model class to instantiate is selected based on the configuration class: The model class to instantiate is selected based on the configuration class:
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model) - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
- isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model) - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
- isInstance of `bert` configuration class: :class:`~transformers.BertForPreTraining` (Bert model) - isInstance of `bert` configuration class: :class:`~transformers.BertForPreTraining` (Bert model)
- isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model) - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
...@@ -504,6 +513,7 @@ class AutoModelForPreTraining: ...@@ -504,6 +513,7 @@ class AutoModelForPreTraining:
- contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model) - contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
- contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model) - contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model) - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
- contains `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
- contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model) - contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
- contains `bert`: :class:`~transformers.BertForPreTraining` (Bert model) - contains `bert`: :class:`~transformers.BertForPreTraining` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model) - contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
...@@ -606,6 +616,7 @@ class AutoModelWithLMHead: ...@@ -606,6 +616,7 @@ class AutoModelWithLMHead:
The model class to instantiate is selected based on the configuration class: The model class to instantiate is selected based on the configuration class:
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model) - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
- isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model) - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
- isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model) - isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model)
- isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model) - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
...@@ -648,6 +659,7 @@ class AutoModelWithLMHead: ...@@ -648,6 +659,7 @@ class AutoModelWithLMHead:
- contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model) - contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
- contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model) - contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model) - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
- contains `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
- contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model) - contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
- contains `bert`: :class:`~transformers.BertForMaskedLM` (Bert model) - contains `bert`: :class:`~transformers.BertForMaskedLM` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model) - contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
......
# coding=utf-8
# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Longformer model. """
import logging
import math
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from .configuration_longformer import LongformerConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_bert import BertPreTrainedModel
from .modeling_roberta import RobertaLMHead, RobertaModel
logger = logging.getLogger(__name__)
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
"longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/pytorch_model.bin",
"longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/pytorch_model.bin",
}
class LongformerSelfAttention(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.output_attentions = config.output_attentions
self.num_heads = config.num_attention_heads
self.head_dim = int(config.hidden_size / config.num_attention_heads)
self.embed_dim = config.hidden_size
self.query = nn.Linear(config.hidden_size, self.embed_dim)
self.key = nn.Linear(config.hidden_size, self.embed_dim)
self.value = nn.Linear(config.hidden_size, self.embed_dim)
# separate projection layers for tokens with global attention
self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
self.dropout = config.attention_probs_dropout_prob
self.layer_id = layer_id
attention_window = config.attention_window[self.layer_id]
assert (
attention_window % 2 == 0
), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
assert (
attention_window > 0
), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
self.one_sided_attention_window_size = attention_window // 2
@staticmethod
def _skew(x, direction):
"""Convert diagonals into columns (or columns into diagonals depending on `direction`"""
x_padded = F.pad(x, direction) # padding value is not important because it will be overwritten
x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
return x_padded
@staticmethod
def _skew2(x):
"""shift every row 1 step to right converting columns into diagonals"""
# X = B x C x M x L
B, C, M, L = x.size()
x = F.pad(x, (0, M + 1)) # B x C x M x (L+M+1). Padding value is not important because it'll be overwritten
x = x.view(B, C, -1) # B x C x ML+MM+M
x = x[:, :, :-M] # B x C x ML+MM
x = x.view(B, C, M, M + L) # B x C, M x L+M
x = x[:, :, :, :-1]
return x
@staticmethod
def _chunk(x, w):
"""convert into overlapping chunkings. Chunk size = 2w, overlap size = w"""
# non-overlapping chunks of size = 2w
x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))
# use `as_strided` to make the chunks overlap with an overlap size = w
chunk_size = list(x.size())
chunk_size[1] = chunk_size[1] * 2 - 1
chunk_stride = list(x.stride())
chunk_stride[1] = chunk_stride[1] // 2
return x.as_strided(size=chunk_size, stride=chunk_stride)
def _mask_invalid_locations(self, input_tensor, w) -> torch.Tensor:
affected_seqlen = w
beginning_mask_2d = input_tensor.new_ones(w, w + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]
ending_mask = beginning_mask.flip(dims=(1, 3))
seqlen = input_tensor.size(1)
beginning_input = input_tensor[:, :affected_seqlen, :, : w + 1]
beginning_mask = beginning_mask[:, :seqlen].expand(beginning_input.size())
beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
ending_input = input_tensor[:, -affected_seqlen:, :, -(w + 1) :]
ending_mask = ending_mask[:, -seqlen:].expand(ending_input.size())
ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
def _sliding_chunks_matmul_qk(self, q: torch.Tensor, k: torch.Tensor, w: int):
"""Matrix multiplicatio of query x key tensors using with a sliding window attention pattern.
This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
with an overlap of size w"""
batch_size, seqlen, num_heads, head_dim = q.size()
assert seqlen % (w * 2) == 0, f"Sequence length should be multiple of {w * 2}. Given {seqlen}"
assert q.size() == k.size()
chunks_count = seqlen // w - 1
# group batch_size and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
q = q.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim)
k = k.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim)
chunk_q = self._chunk(q, w)
chunk_k = self._chunk(k, w)
# matrix multipication
# bcxd: batch_size * num_heads x chunks x 2w x head_dim
# bcyd: batch_size * num_heads x chunks x 2w x head_dim
# bcxy: batch_size * num_heads x chunks x 2w x 2w
chunk_attn = torch.einsum("bcxd,bcyd->bcxy", (chunk_q, chunk_k)) # multiply
# convert diagonals into columns
diagonal_chunk_attn = self._skew(chunk_attn, direction=(0, 0, 0, 1))
# allocate space for the overall attention matrix where the chunks are compined. The last dimension
# has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to
# w previous words). The following column is attention score from each word to itself, then
# followed by w columns for the upper triangle.
diagonal_attn = diagonal_chunk_attn.new_empty((batch_size * num_heads, chunks_count + 1, w, w * 2 + 1))
# copy parts from diagonal_chunk_attn into the compined matrix of attentions
# - copying the main diagonal and the upper triangle
diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, : w + 1]
diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, : w + 1]
# - copying the lower triangle
diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, -(w + 1) : -1, w + 1 :]
diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, : w - 1, 1 - w :]
# separate batch_size and num_heads dimensions again
diagonal_attn = diagonal_attn.view(batch_size, num_heads, seqlen, 2 * w + 1).transpose(2, 1)
self._mask_invalid_locations(diagonal_attn, w)
return diagonal_attn
def _sliding_chunks_matmul_pv(self, prob: torch.Tensor, v: torch.Tensor, w: int):
"""Same as _sliding_chunks_matmul_qk but for prob and value tensors. It is expecting the same output
format from _sliding_chunks_matmul_qk"""
batch_size, seqlen, num_heads, head_dim = v.size()
assert seqlen % (w * 2) == 0
assert prob.size()[:3] == v.size()[:3]
assert prob.size(3) == 2 * w + 1
chunks_count = seqlen // w - 1
# group batch_size and num_heads dimensions into one, then chunk seqlen into chunks of size 2w
chunk_prob = prob.transpose(1, 2).reshape(batch_size * num_heads, seqlen // w, w, 2 * w + 1)
# group batch_size and num_heads dimensions into one
v = v.transpose(1, 2).reshape(batch_size * num_heads, seqlen, head_dim)
# pad seqlen with w at the beginning of the sequence and another w at the end
padded_v = F.pad(v, (0, 0, w, w), value=-1)
# chunk padded_v into chunks of size 3w and an overlap of size w
chunk_v_size = (batch_size * num_heads, chunks_count + 1, 3 * w, head_dim)
chunk_v_stride = padded_v.stride()
chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2]
chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride)
skewed_prob = self._skew2(chunk_prob)
context = torch.einsum("bcwd,bcdh->bcwh", (skewed_prob, chunk_v))
return context.view(batch_size, num_heads, seqlen, head_dim).transpose(1, 2)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
"""
LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`.
Padding to `attention_window` happens in LongformerModel.forward to avoid redoing the padding on each layer.
The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
-ve: no attention
0: local attention
+ve: global attention
`encoder_hidden_states` and `encoder_attention_mask` are not supported and should be None
"""
# TODO: add support for `encoder_hidden_states` and `encoder_attention_mask`
assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None"
assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and shiould be None"
if attention_mask is not None:
attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
key_padding_mask = attention_mask < 0
extra_attention_mask = attention_mask > 0
remove_from_windowed_attention_mask = attention_mask != 0
num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
if max_num_extra_indices_per_batch <= 0:
extra_attention_mask = None
else:
# To support the case of variable number of global attention in the rows of a batch,
# we use the following three selection masks to select global attention embeddings
# in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
# 1) selecting embeddings that correspond to global attention
extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
zero_to_max_range = torch.arange(
0, max_num_extra_indices_per_batch, device=num_extra_indices_per_batch.device
)
# mask indicating which values are actually going to be padding
selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
# 2) location of the non-padding values in the selected global attention
selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
# 3) location of the padding values in the selected global attention
selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
else:
remove_from_windowed_attention_mask = None
extra_attention_mask = None
key_padding_mask = None
hidden_states = hidden_states.transpose(0, 1)
seqlen, batch_size, embed_dim = hidden_states.size()
assert embed_dim == self.embed_dim
q = self.query(hidden_states)
k = self.key(hidden_states)
v = self.value(hidden_states)
q /= math.sqrt(self.head_dim)
q = q.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
k = k.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# attn_weights = (batch_size, seqlen, num_heads, window*2+1)
attn_weights = self._sliding_chunks_matmul_qk(q, k, self.one_sided_attention_window_size)
self._mask_invalid_locations(attn_weights, self.one_sided_attention_window_size)
if remove_from_windowed_attention_mask is not None:
# This implementation is fast and takes very little memory because num_heads x hidden_size = 1
# from (batch_size x seqlen) to (batch_size x seqlen x num_heads x hidden_size)
remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(
dim=-1
)
# cast to fp32/fp16 then replace 1's with -inf
float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(
remove_from_windowed_attention_mask, -10000.0
)
ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones
# diagonal mask with zeros everywhere and -inf inplace of padding
d_mask = self._sliding_chunks_matmul_qk(ones, float_mask, self.one_sided_attention_window_size)
attn_weights += d_mask
assert list(attn_weights.size()) == [
batch_size,
seqlen,
self.num_heads,
self.one_sided_attention_window_size * 2 + 1,
]
# the extra attention
if extra_attention_mask is not None:
selected_k = k.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
# (batch_size, seqlen, num_heads, max_num_extra_indices_per_batch)
selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k))
selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
# concat to attn_weights
# (batch_size, seqlen, num_heads, extra attention count + 2*window+1)
attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
attn_weights_fp32 = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
attn_weights = attn_weights_fp32.type_as(attn_weights)
if key_padding_mask is not None:
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_weights = torch.masked_fill(attn_weights, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
v = v.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
attn = None
if extra_attention_mask is not None:
selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
selected_v = v.new_zeros(batch_size, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
# use `matmul` because `einsum` crashes sometimes with fp16
# attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
attn = torch.matmul(
selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)
).transpose(1, 2)
attn_probs = attn_probs.narrow(
-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch
).contiguous()
if attn is None:
attn = self._sliding_chunks_matmul_pv(attn_probs, v, self.one_sided_attention_window_size)
else:
attn += self._sliding_chunks_matmul_pv(attn_probs, v, self.one_sided_attention_window_size)
assert attn.size() == (batch_size, seqlen, self.num_heads, self.head_dim), "Unexpected size"
attn = attn.transpose(0, 1).reshape(seqlen, batch_size, embed_dim).contiguous()
# For this case, we'll just recompute the attention for these indices
# and overwrite the attn tensor.
# TODO: remove the redundant computation
if extra_attention_mask is not None:
selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, batch_size, embed_dim)
selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[
extra_attention_mask_nonzeros[::-1]
]
q = self.query_global(selected_hidden_states)
k = self.key_global(hidden_states)
v = self.value_global(hidden_states)
q /= math.sqrt(self.head_dim)
q = (
q.contiguous()
.view(max_num_extra_indices_per_batch, batch_size * self.num_heads, self.head_dim)
.transpose(0, 1)
) # (batch_size * self.num_heads, max_num_extra_indices_per_batch, head_dim)
k = (
k.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
) # batch_size * self.num_heads, seqlen, head_dim)
v = (
v.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
) # batch_size * self.num_heads, seqlen, head_dim)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [batch_size * self.num_heads, max_num_extra_indices_per_batch, seqlen]
attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seqlen)
attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0
if key_padding_mask is not None:
attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0,)
attn_weights = attn_weights.view(batch_size * self.num_heads, max_num_extra_indices_per_batch, seqlen)
attn_weights_float = F.softmax(
attn_weights, dim=-1, dtype=torch.float32
) # use fp32 for numerical stability
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
selected_attn = torch.bmm(attn_probs, v)
assert list(selected_attn.size()) == [
batch_size * self.num_heads,
max_num_extra_indices_per_batch,
self.head_dim,
]
selected_attn_4d = selected_attn.view(
batch_size, self.num_heads, max_num_extra_indices_per_batch, self.head_dim
)
nonzero_selected_attn = selected_attn_4d[
selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]
]
attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(
len(selection_padding_mask_nonzeros[0]), -1
).type_as(hidden_states)
context_layer = attn.transpose(0, 1)
if self.output_attentions:
if extra_attention_mask is not None:
# With global attention, return global attention probabilities only
# batch_size x num_heads x max_num_global_attention_tokens x sequence_length
# which is the attention weights from tokens with global attention to all tokens
# It doesn't not return local attention
# In case of variable number of global attantion in the rows of a batch,
# attn_weights are padded with -10000.0 attention scores
attn_weights = attn_weights.view(batch_size, self.num_heads, max_num_extra_indices_per_batch, seqlen)
else:
# without global attention, return local attention probabilities
# batch_size x num_heads x sequence_length x window_size
# which is the attention weights of every token attending to its neighbours
attn_weights = attn_weights.permute(0, 2, 1, 3)
outputs = (context_layer, attn_weights) if self.output_attentions else (context_layer,)
return outputs
LONGFORMER_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
usage and behavior.
Parameters:
config (:class:`~transformers.LongformerConfig`): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
LONGFORMER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.LonmgformerTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to decide the attention given on each token, local attention, global attenion, or no attention (for padding tokens).
Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is important for
task-specific finetuning because it makes the model more flexible at representing the task. For example,
for classification, the <s> token should be given global attention. For QA, all question tokens should also have
global attention. Please refer to the Longformer paper https://arxiv.org/abs/2004.05150 for more details.
Mask values selected in ``[0, 1, 2]``:
``0`` for no attention (padding tokens),
``1`` for local attention (a sliding window attention),
``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`_
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings(
"The bare Longformer Model outputting raw hidden-states without any specific head on top.",
LONGFORMER_START_DOCSTRING,
)
class LongformerModel(RobertaModel):
"""
This class overrides :class:`~transformers.RobertaModel` to provide the ability to process
long sequences following the selfattention approach described in `Longformer: the Long-Document Transformer`_by
Iz Beltagy, Matthew E. Peters, and Arman Cohan. Longformer selfattention combines a local (sliding window)
and global attention to extend to long documents without the O(n^2) increase in memory and compute.
The selfattention module `LongformerSelfAttention` implemented here supports the combination of local and
global attention but it lacks support for autoregressive attention and dilated attention. Autoregressive
and dilated attention are more relevant for autoregressive language modeling than finetuning on downstream
tasks. Future release will add support for autoregressive attention, but the support for dilated attention
requires a custom CUDA kernel to be memory and compute efficient.
.. _`Longformer: the Long-Document Transformer`:
https://arxiv.org/abs/2004.05150
"""
config_class = LongformerConfig
pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "longformer"
def __init__(self, config):
super().__init__(config)
if isinstance(config.attention_window, int):
assert config.attention_window % 2 == 0, "`attention_window` has to be an even value"
assert config.attention_window > 0, "`attention_window` has to be positive"
config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer
else:
assert len(config.attention_window) == config.num_hidden_layers, (
"`len(attention_window)` should equal `num_hidden_layers`. "
f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
)
if config.attention_mode == "bert":
pass # do nothing, use the default `modeling_bert.BertSelfAttention` (will OOM for long sequences)
elif config.attention_mode == "longformer":
for i, layer in enumerate(self.encoder.layer):
# replace the `modeling_bert.BertSelfAttention` object with `LongformerSelfAttention`
layer.attention.self = LongformerSelfAttention(config, layer_id=i)
else:
raise ValueError(
f'Expected values of `attention_mode` are "longformer" or "bert", given {config.attention_mode}'
)
self.init_weights()
def _pad_to_window_size(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor,
position_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
attention_window: int,
pad_token_id: int,
):
"""A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
batch_size, seqlen = input_shape[:2]
padding_len = (attention_window - seqlen % attention_window) % attention_window
if padding_len > 0:
logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
seqlen, seqlen + padding_len, attention_window
)
)
if input_ids is not None:
input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
if attention_mask is not None:
attention_mask = F.pad(
attention_mask, (0, padding_len), value=False
) # no attention on the padding tokens
if token_type_ids is not None:
token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
if position_ids is not None:
# pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id)
if inputs_embeds is not None:
input_ids_padding = inputs_embeds.new_full(
(batch_size, padding_len), self.config.pad_token_id, dtype=torch.long,
)
inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
masked_lm_labels=None,
):
r"""
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
import torch
from transformers import LongformerModel, LongformerTokenizer
model = LongformerModel.from_pretrained('longformer-base-4096')
tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document
input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1
# Attention mask values -- 0: no attention, 1: local attention, 2: global attention
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
attention_mask[:, [1, 4, 21,]] = 2 # Set global attention based on the task. For example,
# classification: the <s> token
# QA: question tokens
# LM: potentially on the beginning of sentences and paragraphs
sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask)
"""
# padding
attention_window = (
self.config.attention_window
if isinstance(self.config.attention_window, int)
else max(self.config.attention_window)
)
padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
attention_window=attention_window,
pad_token_id=self.config.pad_token_id,
)
# embed
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=None,
inputs_embeds=inputs_embeds,
encoder_hidden_states=None,
encoder_attention_mask=None,
)
# undo padding
if padding_len > 0:
# `output` has the following tensors: sequence_output, pooled_output, (hidden_states), (attentions)
# `sequence_output`: unpad because the calling function is expecting a length == input_ids.size(1)
# `pooled_output`: independent of the sequence length
# `hidden_states`: mainly used for debugging and analysis, so keep the padding
# `attentions`: mainly used for debugging and analysis, so keep the padding
output = output[0][:, :-padding_len], *output[1:]
return output
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
class LongformerForMaskedLM(BertPreTrainedModel):
config_class = LongformerConfig
pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "longformer"
def __init__(self, config):
super().__init__(config)
self.longformer = LongformerModel(config)
self.lm_head = RobertaLMHead(config)
self.init_weights()
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
masked_lm_labels=None,
):
r"""
masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
import torch
from transformers import LongformerForMaskedLM, LongformerTokenizer
model = LongformerForMaskedLM.from_pretrained('longformer-base-4096')
tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document
input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1
attention_mask = None # default is local attention everywhere, which is a good choice for MaskedLM
# check ``LongformerModel.forward`` for more details how to set `attention_mask`
loss, prediction_scores = model(input_ids, attention_mask=attention_mask, masked_lm_labels=input_ids)
"""
outputs = self.longformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
...@@ -29,6 +29,7 @@ from .configuration_auto import ( ...@@ -29,6 +29,7 @@ from .configuration_auto import (
ElectraConfig, ElectraConfig,
FlaubertConfig, FlaubertConfig,
GPT2Config, GPT2Config,
LongformerConfig,
OpenAIGPTConfig, OpenAIGPTConfig,
ReformerConfig, ReformerConfig,
RobertaConfig, RobertaConfig,
...@@ -50,6 +51,7 @@ from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFas ...@@ -50,6 +51,7 @@ from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFas
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer
from .tokenization_marian import MarianTokenizer from .tokenization_marian import MarianTokenizer
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer from .tokenization_reformer import ReformerTokenizer
...@@ -73,6 +75,7 @@ TOKENIZER_MAPPING = OrderedDict( ...@@ -73,6 +75,7 @@ TOKENIZER_MAPPING = OrderedDict(
(XLMRobertaConfig, (XLMRobertaTokenizer, None)), (XLMRobertaConfig, (XLMRobertaTokenizer, None)),
(MarianConfig, (MarianTokenizer, None)), (MarianConfig, (MarianTokenizer, None)),
(BartConfig, (BartTokenizer, None)), (BartConfig, (BartTokenizer, None)),
(LongformerConfig, (LongformerTokenizer, None)),
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)), (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(ReformerConfig, (ReformerTokenizer, None)), (ReformerConfig, (ReformerTokenizer, None)),
(ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)), (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
...@@ -105,6 +108,7 @@ class AutoTokenizer: ...@@ -105,6 +108,7 @@ class AutoTokenizer:
- contains `albert`: AlbertTokenizer (ALBERT model) - contains `albert`: AlbertTokenizer (ALBERT model)
- contains `camembert`: CamembertTokenizer (CamemBERT model) - contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model) - contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
- contains `longformer`: LongformerTokenizer (AllenAI Longformer model)
- contains `roberta`: RobertaTokenizer (RoBERTa model) - contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert`: BertTokenizer (Bert model) - contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
...@@ -136,6 +140,7 @@ class AutoTokenizer: ...@@ -136,6 +140,7 @@ class AutoTokenizer:
- contains `albert`: AlbertTokenizer (ALBERT model) - contains `albert`: AlbertTokenizer (ALBERT model)
- contains `camembert`: CamembertTokenizer (CamemBERT model) - contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model) - contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
- contains `longformer`: LongformerTokenizer (AllenAI Longformer model)
- contains `roberta`: RobertaTokenizer (RoBERTa model) - contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert-base-japanese`: BertJapaneseTokenizer (Bert model) - contains `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
- contains `bert`: BertTokenizer (Bert model) - contains `bert`: BertTokenizer (Bert model)
......
# coding=utf-8
# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from .tokenization_roberta import RobertaTokenizer
logger = logging.getLogger(__name__)
# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_longformer_models = ["longformer-base-4096", "longformer-large-4096"]
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"longformer-base-4096": 4096,
"longformer-large-4096": 4096,
}
class LongformerTokenizer(RobertaTokenizer):
# merges and vocab same as Roberta
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_vocab_files_map = {
"vocab_file": {m: vocab_url for m in _all_longformer_models},
"merges_file": {m: merges_url for m in _all_longformer_models},
}
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
import torch
from transformers import (
LongformerConfig,
LongformerModel,
LongformerForMaskedLM,
)
class LongformerModelTester(object):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
attention_window=4,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.attention_window = attention_window
# `ModelTesterMixin.test_attention_outputs` is expecting attention tensors to be of size
# [num_attention_heads, encoder_seq_length, encoder_key_length], but LongformerSelfAttention
# returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1]
# because its local attention only attends to `self.attention_window + 1` locations
self.key_length = self.attention_window + 1
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
# the `test_attention_outputs` and `test_hidden_states_output` tests
self.encoder_seq_length = (
self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
)
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = LongformerConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
attention_window=self.attention_window,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_longformer_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerModel(config=config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_longformer_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False # pruning is not supported
test_headmasking = False # head masking is not supported
test_torchscript = False
all_model_classes = (LongformerForMaskedLM, LongformerModel) if is_torch_available() else ()
def setUp(self):
self.model_tester = LongformerModelTester(self)
self.config_tester = ConfigTester(self, config_class=LongformerConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_longformer_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
def test_longformer_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
class LongformerModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head(self):
model = LongformerModel.from_pretrained("longformer-base-4096")
# 'Hello world! ' repeated 1000 times
input_ids = torch.tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]]) # long input
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
attention_mask[:, [1, 4, 21]] = 2 # Set global attention on a few random positions
output = model(input_ids, attention_mask=attention_mask)[0]
expected_output_sum = torch.tensor(74585.8594)
expected_output_mean = torch.tensor(0.0243)
self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4))
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
@slow
def test_inference_masked_lm(self):
model = LongformerForMaskedLM.from_pretrained("longformer-base-4096")
# 'Hello world! ' repeated 1000 times
input_ids = torch.tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]]) # long input
loss, prediction_scores = model(input_ids, masked_lm_labels=input_ids)
expected_loss = torch.tensor(0.0620)
expected_prediction_scores_sum = torch.tensor(-6.1599e08)
expected_prediction_scores_mean = torch.tensor(-3.0622)
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
self.assertTrue(torch.allclose(prediction_scores.sum(), expected_prediction_scores_sum, atol=1e-4))
self.assertTrue(torch.allclose(prediction_scores.mean(), expected_prediction_scores_mean, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment