"example/vscode:/vscode.git/clone" did not exist on "5512c5e94362acb667af34bba5c32e9817e0fa6e"
Unverified Commit dca34695 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Reformer (#3351)

* first copy & past commit from Bert and morgans LSH code

* add easy way to compare to trax original code

* translate most of function

* make trax lsh self attention deterministic with numpy seed + copy paste code

* add same config

* add same config

* make layer init work

* implemented hash_vectors function for lsh attention

* continue reformer translation

* hf LSHSelfAttentionLayer gives same output as trax layer

* refactor code

* refactor code

* refactor code

* refactor

* refactor + add reformer config

* delete bogus file

* split reformer attention layer into two layers

* save intermediate step

* save intermediate step

* make test work

* add complete reformer block layer

* finish reformer layer

* implement causal and self mask

* clean reformer test and refactor code

* fix merge conflicts

* fix merge conflicts

* update init

* fix device for GPU

* fix chunk length init for tests

* include morgans optimization

* improve memory a bit

* improve comment

* factorize num_buckets

* better testing parameters

* make whole model work

* make lm model work

* add t5 copy paste tokenizer

* add chunking feed forward

* clean config

* add improved assert statements

* make tokenizer work

* improve test

* correct typo

* extend config

* add complexer test

* add new axial position embeddings

* add local block attention layer

* clean tests

* refactor

* better testing

* save intermediate progress

* clean test file

* make shorter input length work for model

* allow variable input length

* refactor

* make forward pass for pretrained model work

* add generation possibility

* finish dropout and init

* make style

* refactor

* add first version of RevNet Layers

* make forward pass work and add convert file

* make uploaded model forward pass work

* make uploaded model forward pass work

* refactor code

* add namedtuples and cache buckets

* correct head masks

* refactor

* made reformer more flexible

* make style

* remove set max length

* add attention masks

* fix up tests

* fix lsh attention mask

* make random seed optional for the moment

* improve memory in reformer

* add tests

* make style

* make sure masks work correctly

* detach gradients

* save intermediate

* correct backprob through gather

* make style

* change back num hashes

* rename to labels

* fix rotation shape

* fix detach

* update

* fix trainer

* fix backward dropout

* make reformer more flexible

* fix conflict

* fix

* fix

* add tests for fixed seed in reformer layer

* fix trainer typo

* fix typo in activations

* add fp16 tests

* add fp16 training

* support fp16

* correct gradient bug in reformer

* add fast gelu

* re-add dropout for embedding dropout

* better naming

* better naming

* renaming

* finalize test branch

* finalize tests

* add more tests

* finish tests

* fix

* fix type trainer

* fix fp16 tests

* fix tests

* fix tests

* fix tests

* fix issue with dropout

* fix dropout seeds

* correct random seed on gpu

* finalize random seed for dropout

* finalize random seed for dropout

* remove duplicate line

* correct half precision bug

* make style

* refactor

* refactor

* docstring

* remove sinusoidal position encodings for reformer

* move chunking to modeling_utils

* make style

* clean config

* make style

* fix tests

* fix auto tests

* pretrained models

* fix docstring

* update conversion file

* Update pretrained_models.rst

* fix rst

* fix rst

* update copyright

* fix test path

* fix test path

* fix small issue in test

* include reformer in generation tests

* add docs for axial position encoding

* finish docs

* Update convert_reformer_trax_checkpoint_to_pytorch.py

* remove isort

* include sams comments

* remove wrong comment in utils

* correct typos

* fix typo

* Update reformer.rst

* applied morgans optimization

* make style

* make gpu compatible

* remove bogus file

* big test refactor

* add example for chunking

* fix typo

* add to README
parent 877fc564
......@@ -163,8 +163,9 @@ At some point in the future, you'll be able to seamlessly move from pre-training
16. **[BART](https://huggingface.co/transformers/model_doc/bart.html)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/pdf/1910.13461.pdf) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
17. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
18. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
18. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
19. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
19. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
20. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
21. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Peason R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
......
......@@ -143,3 +143,14 @@ positional embeddings.
Absolute positional embeddings are selected in the range ``[0, config.max_position_embeddings - 1]``. Some models
use other types of positional embeddings, such as sinusoidal position embeddings or relative position embeddings.
Feed Forward Chunking
--------------------------
In transformers two feed forward layers usually follows the self attention layer in each residual attention block. The intermediate embedding size of the feed forward layers is often bigger than the hidden size of the model (*e.g.* for ``bert-base-uncased``).
For an input of size ``[batch_size, sequence_length]``, the memory required to store the intermediate feed forward embeddings ``[batch_size, sequence_length, config.intermediate_size]`` can account for a large fraction of the memory use. The authors of `Reformer: The Efficient Transformer <https://arxiv.org/abs/2001.04451>`_ noticed that since the computation is independent of the ``sequence_length`` dimension, it is mathematically equivalent to compute the output embeddings of both feed forward layers ``[batch_size, config.hidden_size]_0, ..., [batch_size, config.hidden_size]_n`` individually and concat them afterward to ``[batch_size, sequence_length, config.hidden_size]`` with ``n = sequence_length``, which trades increased computation time against reduced memory use, but yields a mathematically **equivalent** result.
For models employing the function :func:`~.transformers.apply_chunking_to_forward`, the ``chunk_size`` defines the number of output embeddings that are computed in parallel and thus defines the trade-off between memory and time complexity.
If ``chunk_size`` is set to 0, no feed forward chunking is done.
......@@ -107,3 +107,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
model_doc/t5
model_doc/electra
model_doc/dialogpt
model_doc/reformer
......@@ -14,6 +14,12 @@ The base class ``PreTrainedModel`` implements the common methods for loading/sav
.. autoclass:: transformers.PreTrainedModel
:members:
``Helper Functions``
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: transformers.apply_chunking_to_forward
``TFPreTrainedModel``
~~~~~~~~~~~~~~~~~~~~~
......
Reformer
----------------------------------------------------
**DISCLAIMER:** This model is still a work in progress, if you see something strange,
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`_
Overview
~~~~~
The Reformer model was presented in `Reformer: The Efficient Transformer <https://https://arxiv.org/abs/2001.04451.pdf>`_ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
Here the abstract:
*Large Transformer models routinely achieve state-of-the-art results on a number of tasks but training these models can be prohibitively costly, especially on long sequences. We introduce two techniques to improve the efficiency of Transformers. For one, we replace dot-product attention by one that uses locality-sensitive hashing, changing its complexity from O(L^2) to O(Llog(L)), where L is the length of the sequence. Furthermore, we use reversible residual layers instead of the standard residuals, which allows storing activations only once in the training process instead of N times, where N is the number of layers. The resulting model, the Reformer, performs on par with Transformer models while being much more memory-efficient and much faster on long sequences.*
The Authors' code can be found `here <https://github.com/google/trax/tree/master/trax/models/reformer>`_ .
Axial Positional Encodings
~~~~~~~~~~~~~~~~~~~~
Axial Positional Encodings were first implemented in Google's `trax library <https://github.com/google/trax/blob/4d99ad4965bab1deba227539758d59f0df0fef48/trax/layers/research/position_encodings.py#L29>`_ and developed by the authors of this model's paper. In models that are treating very long input sequences, the conventional position id encodings store an embedings vector of size :math:`d` being the ``config.hidden_size`` for every position :math:`i, \ldots, n_s`, with :math:`n_s` being ``config.max_embedding_size``. *E.g.*, having a sequence length of :math:`n_s = 2^{19} \approx 0.5M` and a ``config.hidden_size`` of :math:`d = 2^{10} \approx 1000` would result in a position encoding matrix:
.. math::
X_{i,j}, \text{ with } i \in \left[1,\ldots, d\right] \text{ and } j \in \left[1,\ldots, n_s\right]
which alone has over 500M parameters to store. Axial positional encodings factorize :math:`X_{i,j}` into two matrices:
.. math::
X^{1}_{i,j}, \text{ with } i \in \left[1,\ldots, d^1\right] \text{ and } j \in \left[1,\ldots, n_s^1\right]
and
.. math::
X^{2}_{i,j}, \text{ with } i \in \left[1,\ldots, d^2\right] \text{ and } j \in \left[1,\ldots, n_s^2\right]
with:
.. math::
d = d^1 + d^2 \text{ and } n_s = n_s^1 \times n_s^2 .
Therefore the following holds:
.. math::
X_{i,j} = \begin{cases}
X^{1}_{i, k}, & \text{if }\ i < d^1 \text{ with } k = j \mod n_s^1 \\
X^{2}_{i - d^1, l}, & \text{if } i \ge d^1 \text{ with } l = \lfloor\frac{j}{n_s^1}\rfloor
\end{cases}
Intuitively, this means that a position embedding vector :math:`x_j \in \mathbb{R}^{d}` is now the composition of two factorized embedding vectors: :math:`x^1_{k, l} + x^2_{l, k}`, where as the ``config.max_embedding_size`` dimension :math:`j` is factorized into :math:`k \text{ and } l`.
This design ensures that each position embedding vector :math:`x_j` is unique.
Using the above example again, axial position encoding with :math:`d^1 = 2^5, d^2 = 2^5, n_s^1 = 2^9, n_s^2 = 2^{10}` can drastically reduced the number of parameters to :math:`2^{14} + 2^{15} \approx 49000` parameters.
In practice, the parameter ``config.axial_pos_embds_dim`` is set to ``list``:math:`(d^1, d^2)` which sum has to be equal to ``config.hidden_size`` and ``config.axial_pos_shape`` is set to ``list``:math:`(n_s^1, n_s^2)` and which product has to be equal to ``config.max_embedding_size`` which during training has to be equal to the ``sequence length`` of the ``input_ids``.
LSH Self Attention
~~~~~~~~~~~~~~~~~~~~
In Locality sensitive hashing (LSH) self attention the key and query projection weights are tied. Therefore, the key query embedding vectors are also tied.
LSH self attention uses the locality sensitive
hashing mechanism proposed in `Practical and Optimal LSH for Angular Distance <https://arxiv.org/abs/1509.02897>`_ to assign each of the tied key query embedding vectors to one of ``config.num_buckets`` possible buckets. The premise is that the more "similar" key query embedding vectors (in terms of *cosine similarity*) are to each other, the more likely they are assigned to the same bucket.
The accuracy of the LSH mechanism can be improved by increasing ``config.num_hashes`` or directly the argument ``num_hashes`` of the forward function so that the output of the LSH self attention better approximates the output of the "normal" full self attention.
The buckets are then sorted and chunked into query key embedding vector chunks each of length ``config.lsh_chunk_length``. For each chunk, the query embedding vectors attend to its key vectors (which are tied to themselves) and to the key embedding vectors of ``config.lsh_num_chunks_before`` previous neighboring chunks and ``config.lsh_num_chunks_after`` following neighboring chunks.
For more information, see the `original Paper <https://arxiv.org/abs/2001.04451>`_ or this great `blog post <https://www.pragmatic.ml/reformer-deep-dive/>`_.
Note that ``config.num_buckets`` can also be factorized into a ``list``:math:`(n_{\text{buckets}}^1, n_{\text{buckets}}^2)`. This way instead of assigning the query key embedding vectors to one of :math:`(1,\ldots, n_{\text{buckets}})` they are assigned to one of :math:`(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots, 1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)`. This is crucial for very long sequences to save memory.
It is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length, a good value for ``num_buckets`` are calculated on the fly.
Using LSH self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length.
Local Self Attention
~~~~~~~~~~~~~~~~~~~~
Local self attention is essentially a "normal" self attention layer with
key, query and value projections, but is chunked so that in each chunk of length ``config.local_chunk_length`` the query embedding vectors only attends to the key embedding vectors in its chunk and to the key embedding vectors of ``config.local_num_chunks_before`` previous neighboring chunks and ``config.local_num_chunks_after`` following neighboring chunks.
Using Local self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length.
Training
~~~~~~~~~~~~~~~~~~~~
During training, we must ensure that the sequence length is set to a value that can be divided by the least common multiple of ``config.lsh_chunk_length`` and ``config.local_chunk_length`` and that the parameters of the Axial Positional Encodings are correctly set as described above. Reformer is very memory efficient so that the model can easily be trained on sequences as long as 64000 tokens.
For training, the ``ReformerModelWithLMHead`` should be used as follows:
::
input_ids = tokenizer.encode('This is a sentence from the training data', return_tensors='pt')
loss = model(input_ids, labels=input_ids)[0]
ReformerConfig
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.ReformerConfig
:members:
ReformerTokenizer
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.ReformerTokenizer
:members:
ReformerModel
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.ReformerModel
:members:
ReformerModelWithLMHead
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.ReformerModelWithLMHead
:members:
......@@ -296,3 +296,6 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters |
| | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters |
| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
......@@ -47,6 +47,7 @@ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_marian import MarianConfig
from .configuration_mmbt import MMBTConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
......@@ -138,6 +139,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
......@@ -159,7 +161,7 @@ if is_sklearn_available():
# Modeling
if is_torch_available():
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, top_k_top_p_filtering, apply_chunking_to_forward
from .modeling_auto import (
AutoModel,
AutoModelForPreTraining,
......@@ -190,6 +192,7 @@ if is_torch_available():
BertForQuestionAnswering,
load_tf_weights_in_bert,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BertLayer,
)
from .modeling_openai import (
OpenAIGPTPreTrainedModel,
......@@ -320,6 +323,14 @@ if is_torch_available():
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_reformer import (
ReformerAttention,
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
)
# Optimization
from .optimization import (
AdamW,
......
......@@ -34,12 +34,18 @@ if torch.__version__ < "1.4.0":
else:
gelu = F.gelu
def gelu_fast(x):
return 0.5 * x * (1 + torch.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
ACT2FN = {
"relu": F.relu,
"swish": swish,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
}
......
......@@ -29,6 +29,7 @@ from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_reformer import ReformerConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
......@@ -73,6 +74,7 @@ CONFIG_MAPPING = OrderedDict(
("camembert", CamembertConfig,),
("xlm-roberta", XLMRobertaConfig,),
("bart", BartConfig,),
("reformer", ReformerConfig,),
("roberta", RobertaConfig,),
("flaubert", FlaubertConfig,),
("bert", BertConfig,),
......@@ -130,6 +132,7 @@ class AutoConfig:
- contains `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model)
- contains `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model)
- contains `reformer`: :class:`~transformers.ReformerConfig` (Reformer model)
- contains `bert`: :class:`~transformers.BertConfig` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model)
- contains `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model)
......
# coding=utf-8
# Copyright 2020 The Trax Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Reformer model configuration """
import logging
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json"
}
class ReformerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.ReformerModel`.
It is used to instantiate an Reformer model according to the specified arguments, defining the model
architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
for more information.
Args:
attention_head_size (:obj:`int`, optional, defaults to 64):
Dimensionality of the projected key, query and value vectors
attn_layers (:obj:`list(str)`, optional, defaults to ["local", "lsh", "local", "lsh", "local", "lsh"]):
List of attention layer types in ascending order. It can be chosen between a
LSHSelfAttention layer ("lsh") and a LocalSelfAttention layer ("local").
For more information on LSHSelfAttention layer, see `LSH Self Attention <reformer.html#lsh-self-attention>`__ .
For more information on LocalSelfAttention layer, see `Local Self Attention <reformer.html#local-sensitive-hashing-self-attention>`__ .
axial_pos_embds (:obj:`bool`, optional, defaults to True):
If `True` use axial position embeddings. For more information on how axial position embeddings work, see `Axial Position Encodings <reformer.html#axial-positional-encodings>`__
axial_norm_std (:obj:`float`, optional, defaluts to 1.0):
The standard deviation of the normal_initializer for initializing the weight matrices of the axial positional encodings.
axial_pos_shape (:obj:`list(int)`, optional, defaults to `[64, 64]`):
The position dims of the axial position encodings.
During training the product of the position dims has to equal the sequence length.
For more information on how axial position embeddings work, see `Axial Position Encodings <reformer.html#axial-positional-encodings>`__ncodings.
axial_pos_embds_dim (:obj:`list(int)`, optional, defaults to `[64, 192]`):
The embedding dims of the axial position encodings.
The sum of the embedding dims has to equal the hidden size.
For more information on how axial position embeddings work, see `Axial Position Encodings <reformer.html#axial-positional-encodings>`__ncodings.
chunk_size_lm_head (:obj:`int`, optional, defaults to 0):
The chunk size of the final language model feed forward head layer.
A chunk size of 0 means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
chunk_size_feed_forward (:obj:`int`, optional, defaults to 0):
The chunk size of all feed forward layers in the residual attention blocks.
A chunk size of 0 means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
eos_token_id (:obj:`int`, optional, defaults to 2):
The token id for the <EOS> token.
feed_forward_size (:obj:`int`, optional, defaults to 512):
Dimensionality of the "feed_forward" (i.e., feed-forward) layer in the residual attention block.
hash_seed (:obj:`int`, optional, defaults to `None`):
Seed that can be used to make local sensitive hashing in LSHSelfAttention deterministic. This should only be set for testing purposed. For evaluation and training purposes `hash_seed` should be set to `None` to ensure fully random rotations in local sensitive hashing scheme.
hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "relu"):
The non-linear activation function (function or string) in the feed forward layer in the residual attention block.
If string, "gelu", "relu", "swish", "gelu_new" and "gelu_fast" are supported.
hidden_dropout_prob (:obj:`float`, optional, defaults to 0.05):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
hidden_size (:obj:`int`, optional, defaults to 256):
Dimensionality of the output hidden states of the residual attention blocks.
initializer_range (:obj:`float`, optional, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
is_decoder (:obj:`bool`, optional, defaults to False):
If `is_decoder` is True, a causal mask is used in addition to `attention_mask`.
When using the Reformer for causal language modeling, `is_decoder` is set to `True`.
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
The epsilon used by the layer normalization layers.
local_chunk_length (:obj:`int`, optional, defaults to 64):
Length of chunk which attends to itself in LocalSelfAttention. Chunking reduces memory complexity from sequence length x sequence length (self attention) to chunk length x chunk length x sequence length / chunk length (chunked self attention).
local_num_chunks_before (:obj:`int`, optional, defaults to 1):
Number of previous neighbouring chunks to attend to in LocalSelfAttention layer to itself.
local_num_chunks_after (:obj:`int`, optional, defaults to 0):
Number of following neighbouring chunks to attend to in LocalSelfAttention layer in addition to itself.
local_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout ratio for the attention probabilities in LocalSelfAttention.
lsh_chunk_length (:obj:`int`, optional, defaults to 64):
Length of chunk which attends to itself in LSHSelfAttention. Chunking reduces memory complexity from sequence length x sequence length (self attention) to chunk length x chunk length x sequence length / chunk length (chunked self attention).
lsh_num_chunks_before (:obj:`int`, optional, defaults to 1):
Number of previous neighbouring chunks to attend to in LSHSelfAttention layer to itself.
lsh_num_chunks_after (:obj:`int`, optional, defaults to 0):
Number of following neighbouring chunks to attend to in LSHSelfAttention layer to itself.
lsh_attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout ratio for the attention probabilities in LSHSelfAttention.
max_position_embeddings (:obj:`int`, optional, defaults to 4096):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
num_attention_heads (:obj:`int`, optional, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `64`):
Number of buckets, the key query vectors can be "hashed into" using the locality sensitive hashing scheme. Each query key vector is hashed into a hash in `1, ..., num_buckets`.
The number of buckets can also be factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is factorized into two factors.
The number of buckets (or the product the factors) should approximately equal sequence length / lsh_chunk_length.
num_hashes (:obj:`int`, optional, defaults to 1):
Number of hashing rounds (e.g. number of random rotations) in Local Sensitive Hashing scheme.
The higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive the hashing becomes.
pad_token_id (:obj:`int`, optional, defaults to 0):
The token id for the <PAD> token.
vocab_size (:obj:`int`, optional, defaults to 320):
Vocabulary size of the Reformer model. Defines the different tokens that
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.ReformerModel`.
Example::
from transformers import ReformerModel, ReformerConfig
# Initializing a Reformer configuration
configuration = ReformerConfig()
# Initializing a Reformer model
model = ReformerModel(configuration)
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "reformer"
def __init__(
self,
attention_head_size=64,
attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"],
axial_norm_std=1.0,
axial_pos_embds=True,
axial_pos_shape=[64, 64],
axial_pos_embds_dim=[64, 192],
chunk_size_lm_head=0,
chunk_size_feed_forward=0,
eos_token_id=2,
feed_forward_size=512,
hash_seed=None,
hidden_act="relu",
hidden_dropout_prob=0.05,
hidden_size=256,
initializer_range=0.02,
is_decoder=False,
layer_norm_eps=1e-12,
local_num_chunks_before=1,
local_num_chunks_after=0,
local_attention_probs_dropout_prob=0.05,
local_attn_chunk_length=64,
lsh_attn_chunk_length=64,
lsh_attention_probs_dropout_prob=0.0,
lsh_num_chunks_before=1,
lsh_num_chunks_after=0,
max_position_embeddings=4096,
num_attention_heads=2,
num_buckets=32,
num_hashes=1,
pad_token_id=0,
vocab_size=320,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_decoder=is_decoder, **kwargs)
self.hash_seed = hash_seed
self.vocab_size = vocab_size
self.attention_head_size = attention_head_size
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_hashes = num_hashes
self.num_hidden_layers = len(attn_layers)
self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets
self.lsh_attn_chunk_length = lsh_attn_chunk_length
self.local_attn_chunk_length = local_attn_chunk_length
self.lsh_num_chunks_after = lsh_num_chunks_after
self.lsh_num_chunks_before = lsh_num_chunks_before
self.local_num_chunks_after = local_num_chunks_after
self.local_num_chunks_before = local_num_chunks_before
self.hidden_act = hidden_act
self.feed_forward_size = feed_forward_size
self.hidden_dropout_prob = hidden_dropout_prob
self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob
self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.axial_pos_embds = axial_pos_embds
self.axial_pos_shape = tuple(axial_pos_shape)
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
self.axial_norm_std = axial_norm_std
self.chunk_size_lm_head = chunk_size_lm_head
self.chunk_size_feed_forward = chunk_size_feed_forward
self.attn_layers = attn_layers
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Reformer checkpoint."""
import argparse
import logging
import pickle
import numpy as np
import torch
from transformers import ReformerConfig, ReformerModelWithLMHead
logging.basicConfig(level=logging.INFO)
def set_param(torch_layer, weight, bias=None):
# set parameter of one layer
assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer)
torch_layer.weight = torch.nn.Parameter(weight)
if bias is not None:
assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch_layer)
torch_layer.bias = torch.nn.Parameter(bias)
def set_layer_weights_in_torch_lsh(weights, torch_layer, hidden_size):
# set torch weights for 1-to-1 comparison
np_query_key = np.asarray(weights[0])
np_value = np.asarray(weights[1])
np_dense = np.asarray(weights[2])
set_param(
torch_layer.self_attention.query_key,
torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1),
)
def set_layer_weights_in_torch_local(weights, torch_layer, hidden_size):
# set torch weights for 1-to-1 comparison
np_query = np.asarray(weights[0])
np_key = np.asarray(weights[1])
np_value = np.asarray(weights[2])
np_dense = np.asarray(weights[3])
set_param(
torch_layer.self_attention.query, torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.self_attention.key, torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.self_attention.value, torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.output.dense, torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1),
)
def set_block_weights_in_torch(weights, torch_block, hidden_size):
# layernorm 1
layer_norm_1 = weights[0][0][0]
layer_norm_1_weight = np.asarray(layer_norm_1[0])
layer_norm_1_bias = np.asarray(layer_norm_1[1])
set_param(
torch_block.attention.layer_norm, torch.tensor(layer_norm_1_weight), torch.tensor(layer_norm_1_bias),
)
# lsh weights + output
attn_weights = weights[0][1]
if len(attn_weights) < 4:
set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size)
else:
set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size)
# intermediate weighs
intermediate_weights = weights[2][0][2][2]
# Chunked Feed Forward
if len(intermediate_weights) == 4:
intermediate_weights = intermediate_weights[2]
# layernorm 2
layer_norm_2_weight = np.asarray(intermediate_weights[0][0])
layer_norm_2_bias = np.asarray(intermediate_weights[0][1])
set_param(
torch_block.feed_forward.layer_norm, torch.tensor(layer_norm_2_weight), torch.tensor(layer_norm_2_bias),
)
# intermediate dense
inter_dense_weight = np.asarray(intermediate_weights[1][0])
inter_dense_bias = np.asarray(intermediate_weights[1][1])
set_param(
torch_block.feed_forward.dense.dense,
torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(),
torch.tensor(inter_dense_bias),
)
# intermediate out
out_dense_weight = np.asarray(intermediate_weights[4][0])
out_dense_bias = np.asarray(intermediate_weights[4][1])
set_param(
torch_block.feed_forward.output.dense,
torch.tensor(out_dense_weight).transpose(0, 1).contiguous(),
torch.tensor(out_dense_bias),
)
def set_model_weights_in_torch(weights, torch_model, hidden_size):
# reformer model
torch_model_reformer = torch_model.reformer
# word embeds
word_embeddings = np.asarray(weights[1])
set_param(
torch_model_reformer.embeddings.word_embeddings, torch.tensor(word_embeddings),
)
if isinstance(weights[3], tuple):
position_embeddings = torch_model_reformer.embeddings.position_embeddings
for emb_idx in range(len(position_embeddings.weights)):
emb_weights = np.asarray(weights[3][emb_idx][0])
assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format(
position_embeddings[emb_idx]
)
position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights))
trax_layer_weights = weights[5]
assert len(torch_model_reformer.encoder.layers) * 4 + 1 == len(
trax_layer_weights
), "HF and trax model do not have the same number of layers"
for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers):
block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)]
set_block_weights_in_torch(block_weights, layer, hidden_size)
# output weights
out_weights = weights[6]
# output layer norm
layer_norm_out_weight = np.asarray(out_weights[0][0])
layer_norm_out_bias = np.asarray(out_weights[0][1])
set_param(
torch_model_reformer.encoder.layer_norm,
torch.tensor(layer_norm_out_weight),
torch.tensor(layer_norm_out_bias),
)
# output embeddings
output_embed_weights = np.asarray(out_weights[2][0])
output_embed_bias = np.asarray(out_weights[2][1])
set_param(
torch_model.lm_head.decoder,
torch.tensor(output_embed_weights).transpose(0, 1).contiguous(),
torch.tensor(output_embed_bias),
)
def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch_dump_path):
# Initialise PyTorch model
config = ReformerConfig.from_json_file(config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = ReformerModelWithLMHead(config)
with open(trax_model_pkl_path, "rb") as f:
model_weights = pickle.load(f)["weights"]
set_model_weights_in_torch(model_weights, model, config.hidden_size)
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--trax_model_pkl_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained Reformer model. \n"
"This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_trax_checkpoint_to_pytorch(args.trax_model_pkl_path, args.config_file, args.pytorch_dump_path)
......@@ -31,6 +31,7 @@ from .configuration_auto import (
FlaubertConfig,
GPT2Config,
OpenAIGPTConfig,
ReformerConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
......@@ -97,6 +98,7 @@ from .modeling_flaubert import (
)
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
from .modeling_roberta import (
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM,
......@@ -179,6 +181,7 @@ MODEL_MAPPING = OrderedDict(
(XLMConfig, XLMModel),
(CTRLConfig, CTRLModel),
(ElectraConfig, ElectraModel),
(ReformerConfig, ReformerModel),
]
)
......@@ -222,6 +225,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(CTRLConfig, CTRLLMHeadModel),
(ElectraConfig, ElectraForMaskedLM),
(EncoderDecoderConfig, EncoderDecoderModel),
(ReformerConfig, ReformerModelWithLMHead),
]
)
......
# coding=utf-8
# Copyright 2020 The Trax Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch REFORMER model. """
import logging
import sys
from collections import namedtuple
from functools import reduce
from operator import mul
import numpy as np
import torch
from torch import nn
from torch.autograd.function import Function
from torch.nn import CrossEntropyLoss
from .activations import gelu, gelu_fast, gelu_new, swish
from .configuration_reformer import ReformerConfig
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward
logger = logging.getLogger(__name__)
REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
"google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/pytorch_model.bin"
}
def mish(x):
return x * torch.tanh(nn.functional.softplus(x))
ACT2FN = {
"gelu": gelu,
"relu": torch.nn.functional.relu,
"swish": swish,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
"mish": mish,
}
# Define named tuples for nn.Modules here
LSHSelfAttentionOutput = namedtuple("LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"])
LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"])
AttentionOutput = namedtuple("AttentionOutput", ["hidden_states", "attention_probs", "buckets"])
ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"])
ReformerBackwardOutput = namedtuple(
"ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"]
)
ReformerEncoderOutput = namedtuple("ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions"])
def _get_least_common_mult_chunk_len(config):
attn_types = config.attn_layers
attn_types_set = set(attn_types)
if len(attn_types_set) == 1 and attn_types[0] == "lsh":
return config.lsh_attn_chunk_length
elif len(attn_types_set) == 1 and attn_types[0] == "local":
return config.local_attn_chunk_length
elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]):
return np.lcm(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
else:
raise NotImplementedError(
"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format(
config.attn_layers
)
)
class AxialPositionEmbeddings(nn.Module):
"""Constructs axial position embeddings. Useful for very long input
sequences to save memory and time.
"""
def __init__(self, config):
super().__init__()
self.axial_pos_shape = config.axial_pos_shape
self.axial_pos_embds_dim = config.axial_pos_embds_dim
self.dropout = config.hidden_dropout_prob
self.least_common_mult_chunk_length = _get_least_common_mult_chunk_len(config)
self.weights = nn.ParameterList()
assert (
sum(self.axial_pos_embds_dim) == config.hidden_size
), "Make sure that config.axial_pos_embds factors: {} sum to config.hidden_size: {}".format(
self.axial_pos_embds_dim, config.hidden_size
)
# create weights
for axis, axial_pos_embd_dim in enumerate(self.axial_pos_embds_dim):
# create expanded shapes
ax_shape = [1] * len(self.axial_pos_shape)
ax_shape[axis] = self.axial_pos_shape[axis]
ax_shape = tuple(ax_shape) + (axial_pos_embd_dim,)
# create tensor and init
self.weights.append(nn.Parameter(torch.ones(ax_shape, dtype=torch.float32)))
def forward(self, position_ids):
# broadcast weights to correct shape
batch_size = position_ids.shape[0]
sequence_length = position_ids.shape[1]
broadcasted_weights = [
weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights
]
if self.training is True:
assert (
reduce(mul, self.axial_pos_shape) == sequence_length
), "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(
self.axial_pos_shape, sequence_length
)
if self.dropout > 0:
weights = torch.cat(broadcasted_weights, dim=-1)
# permute weights so that 2D correctly drops dims 1 and 2
transposed_weights = weights.transpose(2, 1)
# drop entire matrix of last two dims (prev dims 1 and 2)
dropped_transposed_weights = nn.functional.dropout2d(
transposed_weights, p=self.dropout, training=self.training
)
dropped_weights = dropped_transposed_weights.transpose(2, 1)
position_encodings = torch.reshape(dropped_weights, (batch_size, sequence_length, -1))
else:
position_encodings = torch.cat(
[torch.reshape(weight, (batch_size, sequence_length, -1)) for weight in broadcasted_weights],
dim=-1,
)
else:
assert (
reduce(mul, self.axial_pos_shape) >= sequence_length
), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format(
self.axial_pos_shape, sequence_length, self.least_common_mult_chunk_length,
)
# reshape axial encodings and use only until sequence_length
position_encodings = torch.cat(broadcasted_weights, dim=-1)
position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[
:, :sequence_length
]
return position_encodings
class PositionEmbeddings(nn.Module):
"""Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`.
"""
def __init__(self, config):
super().__init__()
self.dropout = config.hidden_dropout_prob
self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
def forward(self, position_ids):
position_embeddings = self.embedding(position_ids)
position_embeddings = nn.functional.dropout(position_embeddings, p=self.dropout, training=self.training)
return position_embeddings
class ReformerEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super().__init__()
self.max_position_embeddings = config.max_position_embeddings
self.dropout = config.hidden_dropout_prob
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = (
AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config)
)
def forward(self, input_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
device = input_ids.device
else:
input_shape = inputs_embeds.size()[:-1]
device = inputs_embeds.device
seq_length = input_shape[1]
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
assert (
position_ids.shape[-1] <= self.max_position_embeddings
), "Sequence Length: {} has to be larger equal than config.max_position_embeddings: {}".format(
position_ids.shape[-1], self.max_position_embeddings
)
# dropout
embeddings = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training)
# add positional embeddings
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
return embeddings
class EfficientAttentionMixin:
"""
A few utilities for nn.Modules in Reformer, to be used as a mixin.
"""
def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after):
""" Used to implement attention between consecutive chunks.
Args:
vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...]
num_chunks_before: chunks before current chunk to include in attention
num_chunks_after: chunks after current chunk to include in attention
Returns:
tensor of shape [num_chunks, N * chunk_length, ...], where
N = (1 + num_chunks_before + num_chunks_after).
"""
if num_chunks_before == 0 and num_chunks_after == 0:
return vectors
slices = []
for i in range(-num_chunks_before, num_chunks_after + 1):
if i == 0:
slices.append(vectors)
else:
slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2))
return torch.cat(slices, dim=3)
def _split_hidden_size_dim(self, x, num_attn_heads, attn_head_size):
"""
splits hidden_size dim into attn_head_size and num_attn_heads
"""
new_x_shape = x.size()[:-1] + (num_attn_heads, attn_head_size)
x = x.view(*new_x_shape)
return x.transpose(2, 1)
def _merge_hidden_size_dims(self, x, num_attn_heads, attn_head_size):
"""
merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
x = x.permute(0, 2, 1, 3)
return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size))
def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None):
"""
splits sequence length dim of vectors into `dim_factor_1` and `dim_factor_2` dims
"""
batch_size = vectors.shape[0]
split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2)
if len(vectors.shape) == 4:
return torch.reshape(vectors, split_dim_shape + (attn_head_size,))
elif len(vectors.shape) == 3:
return torch.reshape(vectors, split_dim_shape)
else:
raise ValueError("Input vector rank should be one of [3, 4], but is: {}".format(len(vectors.shape)))
class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
def __init__(self, config):
super().__init__()
self.chunk_length = config.lsh_attn_chunk_length
self.num_hashes = config.num_hashes
self.num_buckets = config.num_buckets
self.num_chunks_before = config.lsh_num_chunks_before
self.num_chunks_after = config.lsh_num_chunks_after
self.hash_seed = config.hash_seed
self.is_decoder = config.is_decoder
self.max_position_embeddings = config.max_position_embeddings
self.dropout = config.lsh_attention_probs_dropout_prob
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = config.attention_head_size
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.hidden_size = config.hidden_size
# projection matrices
self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
# save mask value here. Need fp32 and fp16 mask values
self.register_buffer("self_mask_value_float16", torch.tensor(-1e3))
self.register_buffer("self_mask_value_float32", torch.tensor(-1e5))
self.register_buffer("mask_value_float16", torch.tensor(-1e4))
self.register_buffer("mask_value_float32", torch.tensor(-1e9))
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
num_hashes=None,
do_output_attentions=False,
buckets=None,
**kwargs
):
sequence_length = hidden_states.shape[1]
batch_size = hidden_states.shape[0]
# num hashes can optionally be overwritten by user
num_hashes = num_hashes if num_hashes is not None else self.num_hashes
# project hidden_states to query_key and value
query_key_vectors = self.query_key(hidden_states)
value_vectors = self.value(hidden_states)
# free memory
del hidden_states
query_key_vectors = self._split_hidden_size_dim(
query_key_vectors, self.num_attention_heads, self.attention_head_size
)
value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size)
assert (
query_key_vectors.shape[-1] == self.attention_head_size
), "last dim of query_key_vectors is {} but should be {}.".format(
query_key_vectors.shape[-1], self.attention_head_size
)
assert (
value_vectors.shape[-1] == self.attention_head_size
), "last dim of value_vectors is {} but should be {}.".format(
value_vectors.shape[-1], self.attention_head_size
)
# set `num_buckets` on the fly, recommended way to do it
if self.num_buckets is None:
self._set_num_buckets(sequence_length)
# use cached buckets for backprop only
if buckets is None:
# hash query key vectors into buckets
buckets = self._hash_vectors(query_key_vectors, num_hashes)
assert (
int(buckets.shape[-1]) == num_hashes * sequence_length
), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length)
sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx(
sequence_length, buckets, num_hashes
)
# make sure bucket idx is not longer then sequence length
sorted_bucket_idx = sorted_bucket_idx % sequence_length
# cluster query key value vectors according to hashed buckets
query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes)
value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes)
query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
if self.chunk_length is None:
assert (
self.num_chunks_before == 0 and self.num_chunks_after == 0
), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
# scale key vectors
key_vectors = self._len_and_dim_norm(query_key_vectors)
# get attention probs
out_vectors, logits, attention_probs = self._attend(
query_vectors=query_key_vectors,
key_vectors=key_vectors,
value_vectors=value_vectors,
sorted_bucket_idx=sorted_bucket_idx,
attention_mask=attention_mask,
head_mask=head_mask,
)
# free memory
del query_key_vectors, key_vectors, value_vectors
# sort clusters back to correct ordering
out_vectors, logits = ReverseSort.apply(
out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
)
# sum up all hash rounds
if num_hashes > 1:
out_vectors = self._split_seq_length_dim_to(
out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
)
logits = self._split_seq_length_dim_to(
logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
).unsqueeze(-1)
probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True))
out_vectors = torch.sum(out_vectors * probs_vectors, dim=2)
# free memory
del probs_vectors
# free memory
del logits
assert out_vectors.shape == (
batch_size,
self.num_attention_heads,
sequence_length,
self.attention_head_size,
), "out_vectors have be of shape `[batch_size, config.num_attention_heads, sequence_length, config.attention_head_size]`."
out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size)
if do_output_attentions is False:
attention_probs = ()
return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets)
def _hash_vectors(self, vectors, num_hashes):
batch_size = vectors.shape[0]
# See https://arxiv.org/pdf/1509.02897.pdf
# We sample a different random rotation for each round of hashing to
# decrease the probability of hash misses.
if isinstance(self.num_buckets, int):
assert (
self.num_buckets % 2 == 0
), "There should be an even number of bucktes, but `self.num_bucktes`: {}".format(self.num_buckets)
rotation_size = self.num_buckets
num_buckets = self.num_buckets
else:
# Factorize the hash if self.num_buckets is a list or tuple
rotation_size, num_buckets = 0, 1
for bucket_factor in self.num_buckets:
assert bucket_factor % 2 == 0, "The number of buckets should be even, but `num_bucket`: {}".format(
bucket_factor
)
rotation_size = rotation_size + bucket_factor
num_buckets = num_buckets * bucket_factor
# remove gradient
vectors = vectors.detach()
if self.hash_seed is not None:
# for determinism
torch.manual_seed(self.hash_seed)
rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2)
# create a random self.attention_head_size x num_hashes x num_buckets/2
random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype)
# Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2
rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations)
if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1:
rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1)
buckets = torch.argmax(rotated_vectors, dim=-1)
else:
# Get the buckets for them and combine.
buckets, cur_sum, cur_product = None, 0, 1
for bucket_factor in self.num_buckets:
rotated_vectors_factor = rotated_vectors[..., cur_sum : cur_sum + (bucket_factor // 2)]
cur_sum = cur_sum + bucket_factor // 2
rotated_vectors_factor = torch.cat([rotated_vectors_factor, -rotated_vectors_factor], dim=-1)
if buckets is None:
buckets = torch.argmax(rotated_vectors_factor, dim=-1)
else:
buckets = buckets + (cur_product * torch.argmax(rotated_vectors_factor, dim=-1))
cur_product = cur_product * bucket_factor
# buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len).
# Next we add offsets so that bucket numbers from different hashing rounds don't overlap.
offsets = torch.arange(num_hashes, device=vectors.device)
offsets = (offsets * num_buckets).view((1, 1, -1, 1))
# expand to batch size and num attention heads
offsets = offsets.expand((batch_size, self.num_attention_heads) + offsets.shape[-2:])
offset_buckets = (buckets + offsets).flatten(start_dim=2, end_dim=3)
return offset_buckets
def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes):
# no gradients are needed
with torch.no_grad():
batch_size = buckets.shape[0]
# arange and expand
orig_indices = torch.arange(num_hashes * sequence_length, device=buckets.device).view(1, 1, -1)
orig_indices = orig_indices.expand(batch_size, self.num_attention_heads, orig_indices.shape[-1])
# scale buckets
scaled_buckets = sequence_length * buckets + (orig_indices % sequence_length)
# remove gradient
scaled_buckets = scaled_buckets.detach()
# Hash-based sort
sorted_bucket_idx = torch.argsort(scaled_buckets, dim=-1)
# create simple indices to scatter to, to have undo sort
indices = (
torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device)
.view(1, 1, -1)
.expand(sorted_bucket_idx.shape)
)
# get undo sort
undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size())
undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices)
return sorted_bucket_idx, undo_sorted_bucket_idx
def _set_num_buckets(self, sequence_length):
# recommended `num_buckets` from paper
num_buckets = 2 * sequence_length // self.chunk_length
# factorize `num_buckets` if `num_buckets` becomes too large
num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,)
if num_buckets > 2 * num_buckets_limit:
num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1]
logger.warning("config.num_buckets is not set. Setting config.num_buckets to {}...".format(num_buckets))
self.num_buckets = num_buckets
def _attend(
self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask,
):
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
# get logits and dots
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
# free memory
del query_vectors, key_vectors
query_bucket_idx = self._split_seq_length_dim_to(
sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
)
key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
# get correct mask values depending on precision
if query_key_dots.dtype == torch.float16:
self_mask_value = self.self_mask_value_float16
mask_value = self.mask_value_float16
else:
self_mask_value = self.self_mask_value_float32
mask_value = self.mask_value_float32
mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask)
if mask is not None:
query_key_dots = torch.where(mask, query_key_dots, mask_value)
# free memory
del mask
# Self mask is ALWAYS applied.
# From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf):
# " While attention to the future is not allowed, typical implementations of the
# Transformer do allow a position to attend to itself.
# Such behavior is undesirable in a shared-QK formulation because the dot-product
# of a query vector with itself will almost always be greater than the dot product of a
# query vector with a vector at another position. We therefore modify the masking
# to forbid a token from attending to itself, except in situations
# where a token has no other valid attention targets (e.g. the first token in a sequence) "
self_mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to(
query_bucket_idx.device
)
# apply self_mask
query_key_dots = torch.where(self_mask, query_key_dots, self_mask_value)
# free memory
del self_mask
logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True)
# dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]`
attention_probs = torch.exp(query_key_dots - logits)
# free memory
del query_key_dots
# dropout
attention_probs = nn.functional.dropout(attention_probs, p=self.dropout, training=self.training)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
# attend values
out_vectors = torch.matmul(attention_probs, value_vectors)
# free memory
del value_vectors
# merge chunk length
logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1)
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
return out_vectors, logits, attention_probs
def _compute_attn_mask(self, query_indices, key_indices, attention_mask):
mask = None
# Causal mask
if self.is_decoder:
mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
# Attention mask: chunk, look up correct mask value from key_value_bucket_idx
# IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why.
if attention_mask is not None:
attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
# expand attn_mask to fit with key_value_bucket_idx shape
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
key_attn_mask = torch.gather(attention_mask, -1, key_indices)
query_attn_mask = torch.gather(attention_mask, -1, query_indices)
# expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
# free memory
del query_attn_mask, key_attn_mask, attention_mask
# multiply by casaul mask if necessary
if mask is not None:
mask = mask * attn_mask
else:
mask = attn_mask
return mask
def _len_and_dim_norm(self, vectors):
"""
length and attention head size dim normalization
"""
vectors = self._len_norm(vectors)
vectors = vectors * torch.rsqrt(
torch.tensor(self.attention_head_size, device=vectors.device, dtype=vectors.dtype)
)
return vectors
def _len_norm(self, x, epsilon=1e-6):
"""
length normalization
"""
variance = torch.mean(x ** 2, -1, keepdim=True)
norm_x = x * torch.rsqrt(variance + epsilon)
return norm_x
def _gather_by_expansion(self, vectors, idxs, num_hashes):
"""
expand dims of idxs and vectors for all hashes and gather
"""
expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size)
vectors = vectors.repeat(1, 1, num_hashes, 1)
return torch.gather(vectors, 2, expanded_idxs)
class ReverseSort(Function):
"""
After chunked attention is applied which sorted clusters,
original ordering has to be restored.
Since customized backward function is used for Reformer,
the gradients of the output vectors have to be explicitely
sorted here.
"""
@staticmethod
def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_hashes):
# save sorted_bucket_idx for backprop
with torch.no_grad():
ctx.sorted_bucket_idx = sorted_bucket_idx
ctx.num_hashes = num_hashes
# undo sort to have correct order for next layer
expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape)
out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices)
logits = torch.gather(logits, 2, undo_sorted_bucket_idx)
return out_vectors, logits
@staticmethod
def backward(ctx, grad_out_vectors, grad_logits):
# get parameters saved in ctx
sorted_bucket_idx = ctx.sorted_bucket_idx
num_hashes = ctx.num_hashes
# get real gradient shape
# shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes
grad_logits_shape = grad_logits.shape
# shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes x ChunkLen
grad_out_vectors_shape = grad_out_vectors.shape
# split gradient vectors and sorted bucket idxs by concatenated chunk dimension to gather correct indices
# shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen
grad_logits = grad_logits.view((grad_logits_shape[:2] + (num_hashes, -1)))
# shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen x ChunkLen
grad_out_vectors = grad_out_vectors.view(
(grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:])
)
# reshape and expand
sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_hashes, -1)))
expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape)
# reverse sort of forward
grad_out_vectors = torch.gather(grad_out_vectors, 3, expanded_sort_indices)
grad_logits = torch.gather(grad_logits, 3, sorted_bucket_idx)
# reshape into correct shape
grad_logits = torch.reshape(grad_logits, grad_logits_shape)
grad_out_vectors = torch.reshape(grad_out_vectors, grad_out_vectors_shape)
# return grad and `None` fillers for last 3 forward args
return grad_out_vectors, grad_logits, None, None, None
class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
def __init__(self, config):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.chunk_length = config.local_attn_chunk_length
self.num_chunks_before = config.local_num_chunks_before
self.num_chunks_after = config.local_num_chunks_after
self.is_decoder = config.is_decoder
self.pad_token_id = config.pad_token_id
self.attention_head_size = config.attention_head_size
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.hidden_size = config.hidden_size
# projection matrices
self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
self.dropout = config.local_attention_probs_dropout_prob
# save mask value here
self.register_buffer("mask_value_float16", torch.tensor(-1e4))
self.register_buffer("mask_value_float32", torch.tensor(-1e9))
def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_attentions=False, **kwargs):
sequence_length = hidden_states.shape[1]
batch_size = hidden_states.shape[0]
# project hidden_states to query, key and value
query_vectors = self.query(hidden_states)
key_vectors = self.key(hidden_states)
value_vectors = self.value(hidden_states)
# split last dim into `config.num_attention_heads` and `config.attention_head_size`
query_vectors = self._split_hidden_size_dim(query_vectors, self.num_attention_heads, self.attention_head_size)
key_vectors = self._split_hidden_size_dim(key_vectors, self.num_attention_heads, self.attention_head_size)
value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size)
assert (
query_vectors.shape[-1] == self.attention_head_size
), "last dim of query_key_vectors is {} but should be {}.".format(
query_vectors.shape[-1], self.attention_head_size
)
assert (
key_vectors.shape[-1] == self.attention_head_size
), "last dim of query_key_vectors is {} but should be {}.".format(
key_vectors.shape[-1], self.attention_head_size
)
assert (
value_vectors.shape[-1] == self.attention_head_size
), "last dim of query_key_vectors is {} but should be {}.".format(
value_vectors.shape[-1], self.attention_head_size
)
if self.chunk_length is None:
assert (
self.num_chunks_before == 0 and self.num_chunks_after == 0
), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
# normalize key vectors
key_vectors = key_vectors / torch.sqrt(
torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype)
)
# chunk vectors
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
query_vectors = self._split_seq_length_dim_to(
query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
key_vectors = self._split_seq_length_dim_to(
key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
# chunk indices
indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
batch_size, self.num_attention_heads, 1
)
query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
# append chunks before and after
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after)
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
# free memory
del query_vectors, key_vectors
mask = self._compute_attn_mask(query_indices, key_indices, attention_mask, query_key_dots.shape)
if mask is not None:
# get mask tensor depending on half precision or not
if query_key_dots.dtype == torch.float16:
mask_value = self.mask_value_float16
else:
mask_value = self.mask_value_float32
query_key_dots = torch.where(mask, query_key_dots, mask_value)
# free memory
del mask
# softmax
logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True)
attention_probs = torch.exp(query_key_dots - logits)
# free memory
del logits
# dropout
attention_probs = nn.functional.dropout(attention_probs, p=self.dropout, training=self.training)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
# attend values
out_vectors = torch.matmul(attention_probs, value_vectors)
# free memory
del value_vectors
# merge chunk length
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,)
out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size)
if do_output_attentions is False:
attention_probs = ()
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape):
mask = None
# chunk attention mask and look before and after
if attention_mask is not None:
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
# Causal mask
if self.is_decoder is True:
mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
# Attention mask
if attention_mask is not None:
# create attn_mask
attn_mask = (attention_mask.unsqueeze(-1) * attention_mask_key.unsqueeze(-2)).expand(query_key_dots_shape)
# multiply by casaul mask if necessary
if mask is not None:
mask = mask * attn_mask
else:
mask = attn_mask
return mask
class ReformerSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
all_head_size = config.num_attention_heads * config.attention_head_size
self.dropout = config.hidden_dropout_prob
self.dense = nn.Linear(all_head_size, config.hidden_size, bias=False)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
return hidden_states
class ReformerAttention(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.layer_id = layer_id
self.attn_layers = config.attn_layers
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh":
self.self_attention = LSHSelfAttention(config)
elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local":
self.self_attention = LocalSelfAttention(config)
elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == set(["lsh", "local"]):
# get correct attn layers
if self.attn_layers[self.layer_id] == "lsh":
self.self_attention = LSHSelfAttention(config)
else:
self.self_attention = LocalSelfAttention(config)
else:
raise NotImplementedError(
"Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format(
self.attn_layers
)
)
self.output = ReformerSelfOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
num_hashes=None,
do_output_attentions=False,
buckets=None,
):
hidden_states = self.layer_norm(hidden_states)
# use cached buckets for backprob if buckets not None for LSHSelfAttention
self_attention_outputs = self.self_attention(
hidden_states=hidden_states,
head_mask=head_mask,
attention_mask=attention_mask,
num_hashes=num_hashes,
do_output_attentions=do_output_attentions,
buckets=buckets,
)
attention_output = self.output(self_attention_outputs.hidden_states)
# add buckets if necessary
if hasattr(self_attention_outputs, "buckets"):
buckets = self_attention_outputs.buckets
else:
buckets = None
return AttentionOutput(
hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets,
)
class ReformerFeedForwardDense(nn.Module):
def __init__(self, config):
super().__init__()
self.dropout = config.hidden_dropout_prob
if isinstance(config.hidden_act, str):
self.act_fn = ACT2FN[config.hidden_act]
else:
self.act_fn = config.hidden_act
self.dense = nn.Linear(config.hidden_size, config.feed_forward_size)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = self.act_fn(hidden_states)
return hidden_states
class ReformerFeedForwardOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dropout = config.hidden_dropout_prob
self.dense = nn.Linear(config.feed_forward_size, config.hidden_size)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
return hidden_states
class ChunkReformerFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dense = ReformerFeedForwardDense(config)
self.output = ReformerFeedForwardOutput(config)
def forward(self, attention_output):
return apply_chunking_to_forward(
self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output,
)
def forward_chunk(self, hidden_states):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dense(hidden_states)
return self.output(hidden_states)
class ReformerLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.attention = ReformerAttention(config, layer_id)
# dropout requires to have the same
# seed for forward and backward pass
self.attention_seed = None
self.feed_forward_seed = None
self.feed_forward = ChunkReformerFeedForward(config)
def _init_attention_seed(self):
"""
This function sets a new seed for the
attention layer to make dropout deterministic
for both forward calls: 1 normal forward
call and 1 forward call in backward
to recalculate activations.
"""
# randomize seeds
if next(self.parameters()).device.type == "cuda":
# GPU
device_idx = torch.cuda.current_device()
self.attention_seed = torch.cuda.default_generators[device_idx].seed()
torch.cuda.manual_seed(self.attention_seed)
else:
# CPU
self.attention_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.attention_seed)
def _init_feed_forward_seed(self):
"""
This function sets a new seed for the
feed forward layer to make dropout deterministic
for both forward calls: 1 normal forward
call and 1 forward call in backward
to recalculate activations.
"""
# randomize seeds
if next(self.parameters()).device.type == "cuda":
# GPU
device_idx = torch.cuda.current_device()
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
torch.cuda.manual_seed(self.feed_forward_seed)
else:
# CPU
self.feed_forward_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.feed_forward_seed)
def forward(
self,
prev_attn_output,
hidden_states,
attention_mask=None,
head_mask=None,
num_hashes=None,
do_output_attentions=False,
):
with torch.no_grad():
# every forward pass we sample a different seed
# for dropout and save for forward fn in backward pass
# to have correct dropout
self._init_attention_seed()
attn_outputs = self.attention(
hidden_states=hidden_states,
head_mask=head_mask,
attention_mask=attention_mask,
num_hashes=num_hashes,
do_output_attentions=do_output_attentions,
)
attn_output = attn_outputs.hidden_states
# Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
# Y_1 = X_1 + f(X_2)
attn_output = prev_attn_output + attn_output
# free memory
del prev_attn_output
# every forward pass we sample a different seed
# for dropout and save seed for forward fn in backward
# to have correct dropout
self._init_feed_forward_seed()
# Y_2 = X_2 + g(Y_1)
hidden_states = hidden_states + self.feed_forward(attn_output)
return ReformerOutput(
attn_output=attn_output,
hidden_states=hidden_states,
attention_probs=attn_outputs.attention_probs,
buckets=attn_outputs.buckets,
)
def backward_pass(
self,
next_attn_output,
hidden_states,
grad_attn_output,
grad_hidden_states,
attention_mask=None,
head_mask=None,
buckets=None,
):
# Implements the backward pass for reversible ResNets.
# A good blog post on how this works can be found here:
# Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
# This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
with torch.enable_grad():
next_attn_output.requires_grad = True
# set seed to have correct dropout
torch.manual_seed(self.feed_forward_seed)
# g(Y_1)
res_hidden_states = self.feed_forward(next_attn_output)
res_hidden_states.backward(grad_hidden_states, retain_graph=True)
with torch.no_grad():
# X_2 = Y_2 - g(Y_1)
hidden_states = hidden_states - res_hidden_states
del res_hidden_states
grad_attn_output = grad_attn_output + next_attn_output.grad
next_attn_output.grad = None
with torch.enable_grad():
hidden_states.requires_grad = True
# set seed to have correct dropout
torch.manual_seed(self.attention_seed)
# f(X_2)
# use cached buckets for backprob if buckets not None for LSHSelfAttention
output = self.attention(
hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets,
).hidden_states
output.backward(grad_attn_output, retain_graph=True)
with torch.no_grad():
# X_1 = Y_1 - f(X_2)
attn_output = next_attn_output - output
del output, next_attn_output
grad_hidden_states = grad_hidden_states + hidden_states.grad
hidden_states.grad = None
hidden_states = hidden_states.detach()
return ReformerBackwardOutput(
attn_output=attn_output,
hidden_states=hidden_states,
grad_attn_output=grad_attn_output,
grad_hidden_states=grad_hidden_states,
)
class _ReversibleFunction(Function):
"""
To prevent PyTorch from performing the usual backpropagation,
a customized backward function is implemented here. This way
it is made sure that no memory expensive activations are
saved during the forward pass.
This function is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
"""
@staticmethod
def forward(
ctx,
hidden_states,
layers,
attention_mask,
head_mask,
num_hashes,
all_hidden_states,
all_attentions,
do_output_hidden_states,
do_output_attentions,
):
all_buckets = ()
# split duplicated tensor
hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)
for layer, layer_head_mask in zip(layers, head_mask):
if do_output_hidden_states is True:
all_hidden_states.append(hidden_states)
layer_outputs = layer(
prev_attn_output=attn_output,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
num_hashes=num_hashes,
do_output_attentions=do_output_attentions,
)
attn_output = layer_outputs.attn_output
hidden_states = layer_outputs.hidden_states
all_buckets = all_buckets + (layer_outputs.buckets,)
if do_output_attentions:
all_attentions.append(layer_outputs.attention_probs)
# Add last layer
if do_output_hidden_states is True:
all_hidden_states.append(hidden_states)
# attach params to ctx for backward
ctx.save_for_backward(attn_output.detach(), hidden_states.detach())
ctx.layers = layers
ctx.all_buckets = all_buckets
ctx.head_mask = head_mask
ctx.attention_mask = attention_mask
# Concatenate 2 RevNet outputs
return torch.cat([attn_output, hidden_states], dim=-1)
@staticmethod
def backward(ctx, grad_hidden_states):
grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1)
# retrieve params from ctx for backward
attn_output, hidden_states = ctx.saved_tensors
# create tuple
output = ReformerBackwardOutput(
attn_output=attn_output,
hidden_states=hidden_states,
grad_attn_output=grad_attn_output,
grad_hidden_states=grad_hidden_states,
)
# free memory
del grad_attn_output, grad_hidden_states, attn_output, hidden_states
layers = ctx.layers
all_buckets = ctx.all_buckets
head_mask = ctx.head_mask
attention_mask = ctx.attention_mask
for idx, layer in enumerate(layers[::-1]):
# pop last buckets from stack
buckets = all_buckets[-1]
all_buckets = all_buckets[:-1]
# backprop
output = layer.backward_pass(
next_attn_output=output.attn_output,
hidden_states=output.hidden_states,
grad_attn_output=output.grad_attn_output,
grad_hidden_states=output.grad_hidden_states,
head_mask=head_mask[len(layers) - idx - 1],
attention_mask=attention_mask,
buckets=buckets,
)
assert all_buckets == (), "buckets have to be empty after backpropagation"
grad_hidden_states = torch.cat([output.grad_attn_output, output.grad_hidden_states], dim=-1)
# num of return vars has to match num of forward() args
# return gradient for hidden_states arg and None for other args
return grad_hidden_states, None, None, None, None, None, None, None, None
class ReformerEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.dropout = config.hidden_dropout_prob
self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)])
# Reformer is using Rev Nets, thus last layer outputs are concatenated and
# Layer Norm is done over 2 * hidden_size
self.layer_norm = nn.LayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
num_hashes=None,
do_output_hidden_states=False,
do_output_attentions=False,
):
# hidden_states and attention lists to be filled if wished
all_hidden_states = []
all_attentions = []
# concat same tensor for reversible ResNet
hidden_states = torch.cat([hidden_states, hidden_states], dim=-1)
hidden_states = _ReversibleFunction.apply(
hidden_states,
self.layers,
attention_mask,
head_mask,
num_hashes,
all_hidden_states,
all_attentions,
do_output_hidden_states,
do_output_attentions,
)
# Apply layer norm to concatenated hidden states
hidden_states = self.layer_norm(hidden_states)
# Apply dropout
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
return ReformerEncoderOutput(
hidden_states=hidden_states, all_hidden_states=all_hidden_states, all_attentions=all_attentions
)
class ReformerOnlyLMHead(nn.Module):
def __init__(self, config):
super().__init__()
# Reformer is using Rev Nets, thus last layer outputs are concatenated and
# Layer Norm is done over 2 * hidden_size
self.seq_len_dim = 1
self.chunk_size_lm_head = config.chunk_size_lm_head
self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
def forward_chunk(self, hidden_states):
hidden_states = self.decoder(hidden_states)
return hidden_states
class ReformerPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = ReformerConfig
pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "reformer"
@property
def dummy_inputs(self):
input_ids = torch.tensor(DUMMY_INPUTS)
input_mask = torch.tensor(DUMMY_MASK)
dummy_inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
}
return dummy_inputs
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, AxialPositionEmbeddings):
for weight in module.weights:
torch.nn.init.normal_(weight, std=self.config.axial_norm_std)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
REFORMER_START_DOCSTRING = r"""
Reformer was proposed in
`Reformer: The Efficient Transformer`_
by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
.. _`Reformer: The Efficient Transformer`:
https://arxiv.org/abs/2001.04451
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
usage and behavior.
Parameters:
config (:class:`~transformers.ReformerConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
REFORMER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
During training the input_ids sequence_length has to be a multiple of the relevant model's
chunk lengths (lsh's, local's or both). During evaluation, the indices are automatically
padded to be a multiple of the chunk length.
Indices can be obtained using :class:`transformers.ReformerTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`_
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
num_hashes (:obj:`int`, `optional`, defaults to :obj:`None`):
`num_hashes` is the number of hashing rounds that should be performed during
bucketing. Setting `num_hashes` overwrites the default `num_hashes` defined
in `config.num_hashes`.
For more information, see `num_hashes` in :class:`transformers.ReformerConfig`.
"""
@add_start_docstrings(
"The bare Reformer Model transformer outputting raw hidden-states" "without any specific head on top.",
REFORMER_START_DOCSTRING,
)
class ReformerModel(ReformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
assert (
self.config.num_hidden_layers > 0
), "`config.attn_layers` is empty. Select at least one attn layer form ['lsh', 'local']"
self.embeddings = ReformerEmbeddings(config)
self.encoder = ReformerEncoder(config)
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
num_hashes=None,
do_output_hidden_states=False,
do_output_attentions=False,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``do_output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import ReformerModel, ReformerTokenizer
import torch
tokenizer = ReformerTokenizer.from_pretrained('bert-base-uncased')
model = ReformerModel.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
# TODO(PVP): delete when PR to change output_attentions is made
do_output_attentions = self.config.output_attentions
do_output_hidden_states = self.config.output_hidden_states
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size() # noqa: F841
device = input_ids.device
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1] # noqa: F841
device = inputs_embeds.device
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
assert (
len(input_shape) == 2
), "`input_ids` have be of shape `[batch_size, sequence_length]`, but got shape: {}".format(input_shape)
# prepare head mask
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers, is_attention_chunked=True)
# original sequence length for padding
orig_sequence_length = input_shape[-1]
# if needs padding
least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config)
must_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0
if must_pad_to_match_chunk_length:
padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length
if self.training is True:
raise ValueError(
"If training, sequence Length {} has to be a multiple of least common multiple chunk_length {}. Please consider padding the input to a length of {}.".format(
input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length
)
)
# pad input
input_ids, inputs_embeds, attention_mask, position_ids, input_shape = self._pad_to_mult_of_chunk_length(
input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
input_shape=input_shape,
padding_length=padding_length,
padded_seq_length=least_common_mult_chunk_length,
device=device,
)
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
encoder_outputs = self.encoder(
hidden_states=embedding_output,
head_mask=head_mask,
attention_mask=attention_mask,
num_hashes=num_hashes,
do_output_hidden_states=do_output_hidden_states,
do_output_attentions=do_output_attentions,
)
sequence_output = encoder_outputs.hidden_states
# if padding was applied
if must_pad_to_match_chunk_length:
sequence_output = sequence_output[:, :orig_sequence_length]
outputs = (sequence_output,)
# TODO(PVP): Replace by named tuple after namedtuples are introduced in the library.
if do_output_hidden_states is True:
outputs = outputs + (encoder_outputs.all_hidden_states,)
if do_output_attentions is True:
outputs = outputs + (encoder_outputs.all_attentions,)
return outputs
def _pad_to_mult_of_chunk_length(
self,
input_ids,
inputs_embeds=None,
attention_mask=None,
position_ids=None,
input_shape=None,
padding_length=None,
padded_seq_length=None,
device=None,
):
logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.chunk_length`: {}".format(
input_shape[-1], input_shape[-1] + padding_length, padded_seq_length
)
)
padded_input_ids = torch.full(
(input_shape[0], padding_length), self.config.pad_token_id, device=device, dtype=torch.long,
)
# Extend `attention_mask`
if attention_mask is not None:
attention_mask = torch.cat(
[
attention_mask,
torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype,),
],
dim=-1,
)
else:
attention_mask = torch.cat(
[
torch.ones(input_shape, device=device, dtype=torch.uint8),
torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.uint8),
],
dim=-1,
)
# Extend `input_ids` with padding to match least common multiple chunk_length
if input_ids is not None:
input_ids = torch.cat([input_ids, padded_input_ids], dim=-1)
input_shape = input_ids.size()
# Pad position ids if given
if position_ids is not None:
padded_position_ids = torch.arange(input_shape[-1], padded_seq_length, dtype=torch.long, device=device)
padded_position_ids = position_ids.unsqueeze(0).expand(input_shape[0], padding_length)
position_ids = torch.cat([position_ids, padded_position_ids], dim=-1)
# Extend `input_embeds` with padding to match least common multiple chunk_length
if inputs_embeds is not None:
padded_inputs_embeds = self.embeddings(padded_input_ids, position_ids)
inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2)
input_shape = inputs_embeds.size()
return input_ids, inputs_embeds, attention_mask, position_ids, input_shape
@add_start_docstrings("""Reformer Model with a `language modeling` head on top. """, REFORMER_START_DOCSTRING)
class ReformerModelWithLMHead(ReformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.reformer = ReformerModel(config)
self.lm_head = ReformerOnlyLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.lm_head.decoder
def tie_weights(self):
# word embeddings are not tied in Reformer
pass
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
position_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
num_hashes=None,
labels=None,
do_output_hidden_states=False,
do_output_attentions=False,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[-100, 0, ..., config.vocab_size - 1]`.
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_label` is provided):
Classification loss (cross entropy).
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``do_output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import ReformerModel, ReformerTokenizer
import torch
tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=input_ids)
loss, prediction_scores = outputs[:2]
"""
reformer_outputs = self.reformer(
input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
num_hashes=num_hashes,
do_output_hidden_states=do_output_hidden_states,
do_output_attentions=do_output_attentions,
)
sequence_output = reformer_outputs[0]
logits = self.lm_head(sequence_output)
outputs = (logits,) + reformer_outputs[1:]
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs # (lm_loss), lm_logits, (hidden_states), (attentions)
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
# TODO(PVP): Add smart caching
inputs_dict = {"input_ids": input_ids}
if "num_hashes" in kwargs:
inputs_dict["num_hashes"] = kwargs["num_hashes"]
return inputs_dict
......@@ -13,8 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
import inspect
import logging
import os
from typing import Callable, Tuple
......@@ -175,7 +175,7 @@ class ModuleUtilsMixin:
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def get_head_mask(self, head_mask, num_hidden_layers):
def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
"""
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
......@@ -189,6 +189,8 @@ class ModuleUtilsMixin:
"""
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
if is_attention_chunked is True:
head_mask = head_mask.unsqueeze(-1)
else:
head_mask = [None] * num_hidden_layers
......@@ -786,6 +788,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
**model_specific_kwargs
):
r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
......@@ -863,6 +866,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
use_cache: (`optional`) bool
If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
model_specific_kwargs: (`optional`) dict
Additional model specific kwargs will be forwarded to the `forward` function of the model.
Return:
output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
......@@ -1116,6 +1122,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
)
else:
output = self._generate_no_beam_search(
......@@ -1138,6 +1145,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
)
return output
......@@ -1163,6 +1171,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequence are generated independantly.
......@@ -1175,7 +1184,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
)
outputs = self(**model_inputs)
......@@ -1288,6 +1297,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
):
""" Generate sequences for each example with beam search.
"""
......@@ -1314,7 +1324,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
......@@ -2087,3 +2097,66 @@ def prune_layer(layer, index, dim=None):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
def apply_chunking_to_forward(
chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
) -> torch.Tensor:
"""
This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
It then applies a layer `forward_fn` to each chunk independently to save memory.
If the `forward_fn` is independent across the `chunk_dim` this function will yield the
same result as not applying it.
Args:
chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size`
chunk_dim: int - the dimension over which the input_tensors should be chunked
forward_fn: fn - the forward fn of the model
input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
Returns:
a Tensor with the same shape the foward_fn would have given if applied
Examples::
# rename the usual forward() fn to forward_chunk()
def forward_chunk(self, hidden_states):
hidden_states = self.decoder(hidden_states)
return hidden_states
# implement a chunked forward function
def forward(self, hidden_states):
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
"""
assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
tensor_shape = input_tensors[0].shape
assert all(
input_tensor.shape == tensor_shape for input_tensor in input_tensors
), "All input tenors have to be of the same shape"
# inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
assert num_args_in_forward_chunk_fn == len(
input_tensors
), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
num_args_in_forward_chunk_fn, len(input_tensors)
)
if chunk_size > 0:
assert (
input_tensors[0].shape[chunk_dim] % chunk_size == 0
), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
input_tensors[0][chunk_dim], chunk_size
)
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
# chunk input tensor into tuples
input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
# apply forward fn to every tuple
output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
# concatenate output at same dimension
return torch.cat(output_chunks, dim=chunk_dim)
return forward_fn(*input_tensors)
......@@ -30,6 +30,7 @@ from .configuration_auto import (
FlaubertConfig,
GPT2Config,
OpenAIGPTConfig,
ReformerConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
......@@ -49,6 +50,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_reformer import ReformerTokenizer
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast
......@@ -69,6 +71,7 @@ TOKENIZER_MAPPING = OrderedDict(
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
(BartConfig, (BartTokenizer, None)),
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(ReformerConfig, (ReformerTokenizer, None)),
(ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
(BertConfig, (BertTokenizer, BertTokenizerFast)),
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
......
# coding=utf-8
# Copyright 2020 The Trax Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model Reformer."""
import logging
import os
from shutil import copyfile
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
SPIECE_UNDERLINE = "▁"
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model"
}
}
####################################################
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"google/reformer-crime-and-punishment": 524288,
}
class ReformerTokenizer(PreTrainedTokenizer):
"""
Constructs an Reformer tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__ .
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
should refer to the superclass for more information regarding methods.
Args:
vocab_file (:obj:`string`):
`SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a `.spm` extension) that
contains the vocabulary necessary to instantiate a tokenizer.
eos_token (:obj:`string`, `optional`, defaults to "</s>"):
The end of sequence token.
.. note::
When building a sequence using special tokens, this is not the token that is used for the end
of sequence. The token used is the :obj:`sep_token`.
unk_token (:obj:`string`, `optional`, defaults to "<unk>"):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (:obj:`string`, `optional`, defaults to "<pad>"):
The token used for padding, for example when batching sequences of different lengths.
additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`None`):
Additional special tokens used by the tokenizer.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
eos_token="</s>",
unk_token="<unk>",
pad_token="<pad>",
additional_special_tokens=[],
**kwargs
):
super().__init__(
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
try:
import sentencepiece as spm
except ImportError:
logger.warning(
"You need to install SentencePiece to use ReformerTokenizer:"
"https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
raise
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
@property
def vocab_size(self):
return self.sp_model.get_piece_size()
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state
def __setstate__(self, d):
self.__dict__ = d
try:
import sentencepiece as spm
except ImportError:
logger.warning(
"You need to install SentencePiece to use ReformerTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
raise
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)
def _tokenize(self, text, sample=False):
""" Take as input a string and return a list of strings (tokens) for words/sub-words
"""
if not sample:
pieces = self.sp_model.EncodeAsPieces(text)
else:
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
return pieces
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index < self.sp_model.get_piece_size():
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = self.sp_model.decode_pieces(tokens)
return out_string
def save_vocabulary(self, save_directory):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.
"""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
......@@ -22,6 +22,8 @@ class TestActivations(unittest.TestCase):
get_activation("swish")
get_activation("relu")
get_activation("tanh")
get_activation("gelu_new")
get_activation("gelu_fast")
with self.assertRaises(KeyError):
get_activation("bogus")
with self.assertRaises(KeyError):
......
......@@ -125,6 +125,9 @@ class ModelTesterMixin:
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
config.output_attentions = True
......@@ -138,10 +141,17 @@ class ModelTesterMixin:
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
if self.is_encoder_decoder:
......@@ -175,10 +185,16 @@ class ModelTesterMixin:
self_attentions = outputs[-1]
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_torchscript(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -465,14 +481,16 @@ class ModelTesterMixin:
self.assertEqual(model.config.output_attentions, False)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
if hasattr(self.model_tester, "encoder_seq_length"):
seq_length = self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
seq_length = seq_length * self.model_tester.chunk_length
else:
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[
self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "encoder_seq_length")
else self.model_tester.seq_length,
self.model_tester.hidden_size,
],
list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size],
)
def test_resize_tokens_embeddings(self):
......@@ -485,6 +503,9 @@ class ModelTesterMixin:
model = model_class(config)
model.to(torch_device)
if self.model_tester.is_training is False:
model.eval()
model_vocab_size = config.vocab_size
# Retrieve the embeddings and clone theme
model_embed = model.resize_token_embeddings(model_vocab_size)
......@@ -628,9 +649,13 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
# make sure that input_ids is at most of size 15
input_ids = input_ids[..., :15]
# iterate over all generative models
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
if config.bos_token_id is None:
# if bos token id is not defined, model needs input_ids
......@@ -669,8 +694,12 @@ class ModelTesterMixin:
torch_device
)
# make sure that input_ids is at most of size 15
input_ids = input_ids[..., :15]
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
if config.bos_token_id is None:
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
......@@ -750,7 +779,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
def floats_tensor(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor of the shape within the vocab size."""
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
......
# coding=utf-8 # Copyright 2020 Huggingface
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
from transformers import (
ReformerConfig,
ReformerModel,
ReformerModelWithLMHead,
ReformerTokenizer,
ReformerLayer,
REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
)
import torch
class ReformerModelTester:
def __init__(
self,
parent,
batch_size=None,
seq_length=None,
is_training=None,
is_decoder=None,
use_input_mask=None,
vocab_size=None,
attention_head_size=None,
hidden_size=None,
num_attention_heads=None,
local_attn_chunk_length=None,
local_num_chunks_before=None,
local_num_chunks_after=None,
num_buckets=None,
num_hashes=1,
lsh_attn_chunk_length=None,
lsh_num_chunks_before=None,
lsh_num_chunks_after=None,
chunk_size_lm_head=None,
chunk_size_feed_forward=None,
feed_forward_size=None,
hidden_act=None,
hidden_dropout_prob=None,
local_attention_probs_dropout_prob=None,
lsh_attention_probs_dropout_prob=None,
max_position_embeddings=None,
initializer_range=None,
axial_norm_std=None,
layer_norm_eps=None,
axial_pos_embds=None,
axial_pos_shape=None,
axial_pos_embds_dim=None,
attn_layers=None,
pad_token_id=None,
eos_token_id=None,
scope=None,
hash_seed=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.is_decoder = is_decoder
self.use_input_mask = use_input_mask
self.vocab_size = vocab_size
self.attention_head_size = attention_head_size
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = len(attn_layers)
self.local_attn_chunk_length = local_attn_chunk_length
self.local_num_chunks_after = local_num_chunks_after
self.local_num_chunks_before = local_num_chunks_before
self.num_hashes = num_hashes
self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets
self.lsh_attn_chunk_length = lsh_attn_chunk_length
self.lsh_num_chunks_after = lsh_num_chunks_after
self.lsh_num_chunks_before = lsh_num_chunks_before
self.hidden_act = hidden_act
self.feed_forward_size = feed_forward_size
self.hidden_dropout_prob = hidden_dropout_prob
self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob
self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.axial_pos_embds = axial_pos_embds
self.axial_pos_shape = tuple(axial_pos_shape)
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
self.axial_norm_std = axial_norm_std
self.chunk_size_lm_head = chunk_size_lm_head
self.chunk_size_feed_forward = chunk_size_feed_forward
self.scope = scope
self.attn_layers = attn_layers
self.pad_token_id = pad_token_id
self.hash_seed = hash_seed
attn_chunk_length = local_attn_chunk_length if local_attn_chunk_length is not None else lsh_attn_chunk_length
num_chunks_after = local_num_chunks_after if local_num_chunks_after is not None else lsh_num_chunks_after
num_chunks_before = local_num_chunks_before if local_num_chunks_before is not None else lsh_num_chunks_before
self.encoder_seq_length = seq_length // attn_chunk_length + (self.seq_length % attn_chunk_length != 0)
self.key_length = (num_chunks_before + num_chunks_after + 1) * attn_chunk_length
self.chunk_length = attn_chunk_length
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
config = ReformerConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
feed_forward_size=self.feed_forward_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
local_attention_probs_dropout_prob=self.local_attention_probs_dropout_prob,
lsh_attention_probs_dropout_prob=self.lsh_attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
is_decoder=self.is_decoder,
axial_pos_embds=self.axial_pos_embds,
axial_pos_shape=self.axial_pos_shape,
axial_pos_embds_dim=self.axial_pos_embds_dim,
local_attn_chunk_length=self.local_attn_chunk_length,
local_num_chunks_after=self.local_num_chunks_after,
local_num_chunks_before=self.local_num_chunks_before,
num_hashes=self.num_hashes,
num_buckets=self.num_buckets,
lsh_attn_chunk_length=self.lsh_attn_chunk_length,
lsh_num_chunks_after=self.lsh_num_chunks_after,
lsh_num_chunks_before=self.lsh_num_chunks_before,
attn_layers=self.attn_layers,
pad_token_id=self.pad_token_id,
hash_seed=self.hash_seed,
)
return (
config,
input_ids,
input_mask,
)
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_reformer_model(
self, config, input_ids, input_mask,
):
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
(sequence_output,) = model(input_ids, attention_mask=input_mask)
(sequence_output,) = model(input_ids)
result = {
"sequence_output": sequence_output,
}
# 2 * hidden_size because we use reversible resnet layers
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
)
def create_and_check_reformer_model_with_lm_backward(
self, config, input_ids, input_mask,
):
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
loss.backward()
def create_and_check_reformer_with_lm(
self, config, input_ids, input_mask,
):
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.check_loss_output(result)
def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder):
# no special position embeddings
config.axial_pos_embds = False
config.is_decoder = is_decoder
if self.lsh_attn_chunk_length is not None:
# need to set chunk length equal sequence length to be certain that chunking works
config.lsh_attn_chunk_length = self.seq_length
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
# set all position encodings to zero so that postions don't matter
with torch.no_grad():
embedding = model.embeddings.position_embeddings.embedding
embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
embedding.weight.requires_grad = False
half_seq_len = self.seq_length // 2
roll = self.chunk_length
half_input_ids = input_ids[:, :half_seq_len]
# normal padded
attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,)
input_ids_padded = torch.cat(
[half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1,
)
# shifted padded
input_ids_roll = torch.cat(
[half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1,
)
input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1)
attn_mask_roll = torch.roll(attn_mask, roll, dims=-1)
output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len]
output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll : half_seq_len + roll]
self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder):
config.is_decoder = is_decoder
layer = ReformerLayer(config).to(torch_device)
layer.train()
shape = (
self.batch_size,
self.seq_length,
config.hidden_size,
) # Batch x SeqLen x hiddenSize
# get random tensors
hidden_states = floats_tensor(shape)
prev_attn_output = floats_tensor(shape)
# now the random seeds for attention and feed forward is initialized
# forward tensors with dropout
layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask)
next_attn_output = layer_outputs.attn_output
next_hidden_states = layer_outputs.hidden_states
torch.manual_seed(layer.attention_seed)
attn_outputs = layer.attention(hidden_states, attention_mask=input_mask)
self.parent.assertTrue(
torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,)
)
torch.manual_seed(layer.feed_forward_seed)
feed_forward_hidden_states = layer.feed_forward(next_attn_output)
self.parent.assertTrue(
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
)
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask):
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]
config.chunk_size_lm_head = 1
config.chunk_size_feed_forward = 1
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask):
if not self.is_training:
return
# disable dropout
config.hidden_dropout_prob = 0
config.local_attention_probs_dropout_prob = 0
config.lsh_attention_probs_dropout_prob = 0
torch.manual_seed(0)
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.train()
model.zero_grad()
loss_no_chunk, output_no_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[:2]
loss_no_chunk.backward()
grad_slice_word_no_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
grad_slice_position_factor_1_no_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
grad_slice_position_factor_2_no_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
config.chunk_size_lm_head = 1
config.chunk_size_feed_forward = 1
torch.manual_seed(0)
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.train()
model.zero_grad()
loss_chunk, output_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[:2]
loss_chunk.backward()
grad_slice_word_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
grad_slice_position_factor_1_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
grad_slice_position_factor_2_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
self.parent.assertTrue(torch.allclose(loss_chunk, loss_no_chunk, atol=1e-3))
self.parent.assertTrue(torch.allclose(grad_slice_word_no_chunk, grad_slice_word_chunk, atol=1e-3))
self.parent.assertTrue(
torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3)
)
self.parent.assertTrue(
torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)
)
def create_and_check_reformer_random_seed(self, config, input_ids, input_mask):
layer = ReformerLayer(config).to(torch_device)
layer.train()
shape = (
self.batch_size,
self.seq_length,
config.hidden_size,
) # Batch x SeqLen x hiddenSize
hidden_states = floats_tensor(shape)
attn_output = floats_tensor(shape)
seeds = []
for _ in range(100):
layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask)
attn_output = layer_outputs.attn_output
hidden_states = layer_outputs.hidden_states
torch.manual_seed(layer.attention_seed)
seeds.append(layer.attention_seed)
self.parent.assertGreater(len(set(seeds)), 70)
seeds = []
for _ in range(100):
layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask)
attn_output = layer_outputs.attn_output
hidden_states = layer_outputs.hidden_states
torch.manual_seed(layer.feed_forward_seed)
seeds.append(layer.feed_forward_seed)
self.parent.assertGreater(len(set(seeds)), 70)
def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask):
model = ReformerModel(config=config)
model.to(torch_device)
model.half()
model.eval()
output = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask):
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.half()
model.eval()
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
self.parent.assertFalse(torch.isnan(output).any().item())
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask,) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
class ReformerTesterMixin:
"""
Reformer Local and Reformer LSH run essentially the same tests
"""
def test_config(self):
self.config_tester.run_common_tests()
def test_reformer_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model(*config_and_inputs)
def test_reformer_lm_model_backward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_with_lm_backward(*config_and_inputs)
def test_reformer_model_attn_masking(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True)
self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False)
def test_reformer_with_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs)
def test_reformer_layer_training_dropout(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True)
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False)
def test_reformer_chunking_forward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs)
def test_reformer_chunking_backward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
@slow
def test_dropout_random_seed_is_changing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_reformer_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_reformer_model_fp16_generate(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
@require_torch
class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
def prepare_kwargs(self):
return {
"batch_size": 13,
"seq_length": 32,
"is_training": True,
"is_decoder": False,
"use_input_mask": True,
"vocab_size": 32,
"attention_head_size": 16,
"hidden_size": 32,
"num_attention_heads": 2,
"local_attn_chunk_length": 4,
"local_num_chunks_before": 1,
"local_num_chunks_after": 0,
"chunk_size_lm_head": 0,
"chunk_size_feed_forward": 0,
"feed_forward_size": 32,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"local_attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"initializer_range": 0.02,
"axial_norm_std": 1.0,
"layer_norm_eps": 1e-12,
"axial_pos_embds": True,
"axial_pos_shape": [4, 8],
"axial_pos_embds_dim": [16, 16],
"attn_layers": ["local", "local", "local", "local"],
"pad_token_id": 0,
"eos_token_id": 2,
"scope": None,
"hash_seed": 0,
}
def setUp(self):
tester_kwargs = self.prepare_kwargs()
self.model_tester = ReformerModelTester(self, **tester_kwargs)
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
@slow
def test_model_from_pretrained(self):
for model_name in list(REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = ReformerModelWithLMHead.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch
class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
def prepare_kwargs(self):
return {
"batch_size": 13,
"seq_length": 13,
"use_input_mask": True,
"is_training": False,
"is_decoder": False,
"vocab_size": 32,
"attention_head_size": 16,
"hidden_size": 64,
"num_attention_heads": 2,
"num_buckets": 2,
"num_hashes": 4,
"lsh_attn_chunk_length": 4,
"lsh_num_chunks_before": 2,
"lsh_num_chunks_after": 3,
"chunk_size_lm_head": 5,
"chunk_size_feed_forward": 6,
"feed_forward_size": 32,
"hidden_act": "relu",
"hidden_dropout_prob": 0.1,
"lsh_attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"initializer_range": 0.02,
"axial_norm_std": 1.0,
"layer_norm_eps": 1e-12,
"axial_pos_embds": True,
"axial_pos_shape": [4, 8],
"axial_pos_embds_dim": [16, 48],
"attn_layers": ["lsh", "lsh", "lsh", "lsh"],
"pad_token_id": 0,
"eos_token_id": 2,
"scope": None,
"hash_seed": 0,
}
def setUp(self):
tester_kwargs = self.prepare_kwargs()
self.model_tester = ReformerModelTester(self, **tester_kwargs)
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
@require_torch
class ReformerIntegrationTests(unittest.TestCase):
"""
These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/04/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `branch_to_save_trax_integration_tests`.
"""
def _get_basic_config_and_input(self):
config = {
"vocab_size": 320,
"attention_head_size": 8,
"hidden_size": 16,
"num_attention_heads": 2,
"num_buckets": 2,
"num_hashes": 4,
"lsh_attn_chunk_length": 4,
"local_attn_chunk_length": 4,
"lsh_num_chunks_before": 1,
"lsh_num_chunks_after": 0,
"local_num_chunks_before": 1,
"local_num_chunks_after": 0,
"chunk_size_lm_head": 0,
"chunk_size_feed_forward": 0,
"feed_forward_size": 32,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"lsh_attention_probs_dropout_prob": 0.0,
"local_attention_probs_dropout_prob": 0.0,
"max_position_embeddings": 32,
"initializer_range": 0.02,
"axial_norm_std": 1.0,
"layer_norm_eps": 1e-12,
"sinusoidal_pos_embds": False,
"axial_pos_embds": True,
"axial_pos_shape": [4, 8],
"axial_pos_embds_dim": [8, 8],
"hash_seed": 0,
"is_decoder": True,
}
return config
def _get_hidden_states(self):
return torch.tensor(
[
[
[
1.90826353e00,
-1.45999730e00,
-6.20405462e-01,
1.52503433e00,
-3.64464232e-01,
-8.27359235e-01,
8.39670803e-01,
2.44492178e-01,
4.98332758e-01,
2.69175139e00,
-7.08081422e-03,
1.04915401e00,
-1.83476661e00,
7.67220476e-01,
2.98580543e-01,
2.84803992e-02,
],
[
-2.66374286e-02,
4.33497576e-01,
3.10386309e-01,
5.46039944e-01,
-2.47292666e-04,
-7.52305019e-01,
2.39162103e-01,
7.25216186e-01,
-7.58357372e-01,
4.20635998e-01,
-4.04739919e-02,
1.59924145e-01,
2.05135748e00,
-1.15997978e00,
5.37166397e-01,
2.62873606e-01,
],
[
1.85247482e-01,
7.07046037e-01,
-6.77089715e-01,
-2.24209655e00,
-3.75307980e-02,
-8.59380874e-01,
-2.81027884e00,
1.01276376e00,
-1.69438001e00,
4.17574660e-01,
-1.49196962e00,
-1.76483717e00,
-1.94566312e-01,
-1.71183858e00,
7.72903565e-01,
-1.11557056e00,
],
[
9.46069193e-01,
1.53417623e-01,
-9.58686996e-01,
1.18126669e-01,
1.75967724e00,
1.62194590e00,
-5.74108159e-01,
6.79920443e-01,
5.44028163e-01,
2.05466114e-01,
-3.63045868e-01,
2.41865062e-01,
3.20348382e-01,
-9.05611176e-01,
-1.92690727e-01,
-1.19917547e00,
],
]
],
dtype=torch.float32,
device=torch_device,
)
def _get_attn_mask(self):
return torch.tensor([[0, 1, 0, 0]], dtype=torch.long, device=torch_device)
def _get_input_ids_and_mask(self):
mask = torch.tensor(
[
[1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0],
],
dtype=torch.long,
device=torch_device,
)
input_ids = torch.tensor(
[
[
89,
279,
286,
84,
194,
316,
182,
28,
283,
37,
169,
7,
253,
267,
107,
250,
44,
7,
102,
62,
3,
243,
171,
265,
302,
48,
164,
264,
148,
229,
280,
150,
],
[
9,
192,
66,
112,
163,
83,
135,
70,
224,
96,
31,
80,
196,
80,
63,
22,
85,
100,
47,
283,
0,
163,
126,
143,
195,
82,
53,
82,
18,
27,
182,
52,
],
],
dtype=torch.long,
device=torch_device,
)
return input_ids, mask
def test_lsh_layer_forward(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["lsh"]
config["is_decoder"] = False
hidden_states = self._get_hidden_states()
torch.manual_seed(0)
layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
layer.eval()
reformer_output = layer(prev_attn_output=hidden_states.clone(), hidden_states=hidden_states)
output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[1.6879, -1.3083, -0.4708, 1.3555, -0.6292], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_lsh_layer_forward_complex(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["lsh"]
config["num_buckets"] = [2, 4]
attn_mask = self._get_attn_mask()
hidden_states = self._get_hidden_states()
torch.manual_seed(0)
layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
layer.eval()
reformer_output = layer(
prev_attn_output=hidden_states.clone(), hidden_states=hidden_states, attention_mask=attn_mask,
)
output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[1.6439, -1.2306, -0.5108, 1.3006, -0.6537], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_local_layer_forward(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["local"]
config["is_decoder"] = False
hidden_states = self._get_hidden_states()
torch.manual_seed(0)
layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
layer.eval()
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states)
output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[1.4212, -2.0576, -0.9688, 1.4599, -0.1344], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_local_layer_forward_complex(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["local"]
attn_mask = self._get_attn_mask()
hidden_states = self._get_hidden_states()
torch.manual_seed(0)
layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
layer.eval()
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_lsh_model_forward(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"]
config["num_buckets"] = [2, 4]
torch.manual_seed(0)
model = ReformerModel(ReformerConfig(**config)).to(torch_device)
model.eval()
input_ids, attn_mask = self._get_input_ids_and_mask()
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
output_slice = hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[-0.9896, -0.9396, -1.0831, -0.0597, 0.2456], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_local_model_forward(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["local", "local", "local", "local"]
torch.manual_seed(0)
model = ReformerModel(ReformerConfig(**config)).to(torch_device)
model.eval()
input_ids, attn_mask = self._get_input_ids_and_mask()
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
output_slice = hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[-1.6791, 0.7171, 0.1594, 0.4063, 1.2584], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_lm_model_forward(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["local", "lsh", "local", "lsh", "local", "lsh"]
config["num_buckets"] = [2, 4]
config["is_decoder"] = False
torch.manual_seed(0)
model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
model.eval()
input_ids, attn_mask = self._get_input_ids_and_mask()
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
output_slice = hidden_states[1, -1, :5]
expected_output_slice = torch.tensor(
[0.0324, -0.0121, 0.0615, 0.0031, -0.0297], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_local_lm_model_grad(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["local", "local", "local", "local"]
config["hidden_dropout_prob"] = 0.0
config["local_attention_probs_dropout_prob"] = 0.0
torch.manual_seed(0)
model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
model.train()
model.zero_grad()
input_ids, _ = self._get_input_ids_and_mask()
loss = model(input_ids=input_ids, labels=input_ids)[0]
self.assertTrue(torch.allclose(loss, torch.tensor(5.7786, dtype=torch.float, device=torch_device), atol=1e-3))
loss.backward()
# check last grads to cover all proable errors
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
expected_grad_slice_word = torch.tensor(
[-0.0005, 0.0001, 0.0002, 0.0003, 0.0006], dtype=torch.float, device=torch_device,
)
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
expected_grad_slice_pos_fac_1 = torch.tensor(
[0.0037, -1.3793, -1.0231, -1.5230, -2.5306], dtype=torch.float, device=torch_device,
)
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
expected_grad_slice_pos_fac_2 = torch.tensor(
[-1.3165, 0.5168, 0.7785, 1.0811, -0.9830], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3))
self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3))
self.assertTrue(torch.allclose(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3))
def test_lsh_lm_model_grad(self):
config = self._get_basic_config_and_input()
config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"]
config["hidden_dropout_prob"] = 0.0
config["lsh_attention_probs_dropout_prob"] = 0.0
config["num_buckets"] = [2, 4]
config["num_hashes"] = 6
torch.manual_seed(0)
model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
model.train()
model.zero_grad()
input_ids, _ = self._get_input_ids_and_mask()
loss = model(input_ids=input_ids, labels=input_ids)[0]
self.assertTrue(torch.allclose(loss, torch.tensor(5.7819, dtype=torch.float, device=torch_device), atol=1e-3))
loss.backward()
# check last grads to cover all proable errors
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
expected_grad_slice_word = torch.tensor(
[2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04], dtype=torch.float, device=torch_device,
)
grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
expected_grad_slice_pos_fac_1 = torch.tensor(
[-0.0984, 0.6283, 0.4282, 1.2960, 0.6897], dtype=torch.float, device=torch_device,
)
grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
expected_grad_slice_pos_fac_2 = torch.tensor(
[0.4626, -0.0231, -0.0172, 0.1081, 0.3805], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3))
self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3))
self.assertTrue(torch.allclose(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3))
@slow
def test_pretrained_generate_crime_and_punish(self):
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(torch_device)
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model.eval()
input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device)
output_ids = model.generate(
input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8
)
output_text = tokenizer.decode(output_ids[0])
self.assertEqual(
output_text,
"A few months later state expression in his ideas, at the first entrance. He was positively for an inst",
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment