Unverified Commit dc3f6758 authored by Vasudev Gupta, committed by GitHub

Add BigBirdPegasus (#10991)



* init bigbird pegasus

* add debugging nb ; update config

* init conversion

* update conversion script

* complete conversion script

* init forward()

* complete forward()

* add tokenizer

* add some slow tests

* commit current

* fix copies

* add docs

* add conversion script for bigbird-roberta-summarization

* remove TODO

* small fixups

* correct tokenizer

* add bigbird core for now

* fix config

* fix more

* revert pegasus-tokenizer back

* make style

* everything working for pubmed; yayy

* complete tests finally

* remove bigbird pegasus tok

* correct tokenizer

* correct tests

* add tokenizer files

* finish make style

* fix test

* update

* make style

* fix tok utils base file

* make fix-copies

* clean a bit

* small update

* fix some suggestions

* add to readme

* fix a bit, clean tests

* fix more tests

* Update src/transformers/__init__.py

* Update src/transformers/__init__.py

* make fix-copies

* complete attn switching, auto-padding left

* make style

* fix auto-padding test

* make style

* fix batched attention tests

* put tolerance at 1e-1 for stand-alone decoder test

* fix docs

* fix tests

* correct slow tokenizer conversion

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* complete remaining suggestions

* fix test
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 6f40e317
......@@ -195,6 +195,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[BERT](https://huggingface.co/transformers/model_doc/bert.html)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
1. **[BERT For Sequence Generation](https://huggingface.co/transformers/model_doc/bertgeneration.html)** (from Google) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
1. **[BigBird-RoBERTa](https://huggingface.co/transformers/model_doc/bigbird.html)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed.
1. **[BigBird-Pegasus](https://huggingface.co/transformers/model_doc/bigbird_pegasus.html)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed.
1. **[Blenderbot](https://huggingface.co/transformers/model_doc/blenderbot.html)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BlenderbotSmall](https://huggingface.co/transformers/model_doc/blenderbot_small.html)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BORT](https://huggingface.co/transformers/model_doc/bort.html)** (from Alexa) released with the paper [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) by Adrian de Wynter and Daniel J. Perry.
......
This diff is collapsed.
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
BigBirdPegasus
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The BigBird model was proposed in `Big Bird: Transformers for Longer Sequences <https://arxiv.org/abs/2007.14062>`__ by
Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon,
Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird is a sparse-attention-
based transformer that extends Transformer-based models, such as BERT, to much longer sequences. In addition to sparse
attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it
has been shown that applying sparse, global, and random attention approximates full attention, while being
computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context,
BigBird has shown improved performance on various long document NLP tasks, such as question answering and
summarization, compared to BERT or RoBERTa.
The abstract from the paper is the following:
*Transformers-based models, such as BERT, have been one of the most successful deep learning models for NLP.
Unfortunately, one of their core limitations is the quadratic dependency (mainly in terms of memory) on the sequence
length due to their full attention mechanism. To remedy this, we propose, BigBird, a sparse attention mechanism that
reduces this quadratic dependency to linear. We show that BigBird is a universal approximator of sequence functions and
is Turing complete, thereby preserving these properties of the quadratic, full attention model. Along the way, our
theoretical analysis reveals some of the benefits of having O(1) global tokens (such as CLS), that attend to the entire
sequence as part of the sparse attention mechanism. The proposed sparse attention can handle sequences of length up to
8x of what was previously possible using similar hardware. As a consequence of the capability to handle longer context,
BigBird drastically improves performance on various NLP tasks such as question answering and summarization. We also
propose novel applications to genomics data.*
Tips:
- For an in-depth explanation of how BigBird's attention works, see `this blog post
<https://huggingface.co/blog/big-bird>`__.
- BigBird comes with 2 implementations: **original_full** & **block_sparse**. For sequence lengths < 1024, using
**original_full** is advised as there is no benefit in using **block_sparse** attention (see the usage sketch below).
- The code currently uses a window size of 3 blocks and 2 global blocks.
- Sequence length must be divisible by block size.
- Current implementation supports only **ITC**.
- Current implementation doesn't support **num_random_blocks = 0**.
- BigBirdPegasus uses the `PegasusTokenizer
<https://github.com/huggingface/transformers/blob/master/src/transformers/models/pegasus/tokenization_pegasus.py>`__.
The original code can be found `here <https://github.com/google-research/bigbird>`__.
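As a usage sketch (assuming the ``google/bigbird-pegasus-large-arxiv`` checkpoint referenced in this PR is available),
the attention implementation can be selected at load time through the configuration arguments::

    >>> from transformers import PegasusTokenizer, BigBirdPegasusForConditionalGeneration

    >>> tokenizer = PegasusTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv")
    >>> # inputs shorter than 1024 tokens -> fall back to full attention, as advised in the tips above
    >>> model = BigBirdPegasusForConditionalGeneration.from_pretrained(
    ...     "google/bigbird-pegasus-large-arxiv", attention_type="original_full"
    ... )

    >>> inputs = tokenizer("a long scientific article ...", return_tensors="pt")
    >>> summary_ids = model.generate(**inputs, max_length=64)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))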
BigBirdPegasusConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BigBirdPegasusConfig
:members:
BigBirdPegasusModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BigBirdPegasusModel
:members: forward
BigBirdPegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BigBirdPegasusForConditionalGeneration
:members: forward
BigBirdPegasusForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BigBirdPegasusForSequenceClassification
:members: forward
BigBirdPegasusForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BigBirdPegasusForQuestionAnswering
:members: forward
BigBirdPegasusForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BigBirdPegasusForCausalLM
:members: forward
......@@ -155,6 +155,10 @@ _import_structure = {
"models.bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"],
"models.bertweet": ["BertweetTokenizer"],
"models.big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig", "BigBirdTokenizer"],
"models.bigbird_pegasus": [
"BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP",
"BigBirdPegasusConfig",
],
"models.blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig", "BlenderbotTokenizer"],
"models.blenderbot_small": [
"BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
......@@ -543,6 +547,16 @@ if is_torch_available():
"load_tf_weights_in_big_bird",
]
)
_import_structure["models.bigbird_pegasus"].extend(
[
"BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST",
"BigBirdPegasusForCausalLM",
"BigBirdPegasusForConditionalGeneration",
"BigBirdPegasusForQuestionAnswering",
"BigBirdPegasusForSequenceClassification",
"BigBirdPegasusModel",
]
)
_import_structure["models.blenderbot"].extend(
[
"BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -1541,6 +1555,7 @@ if TYPE_CHECKING:
from .models.bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
from .models.bertweet import BertweetTokenizer
from .models.big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig, BigBirdTokenizer
from .models.bigbird_pegasus import BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdPegasusConfig
from .models.blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig, BlenderbotTokenizer
from .models.blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -1885,6 +1900,14 @@ if TYPE_CHECKING:
BigBirdPreTrainedModel,
load_tf_weights_in_big_bird,
)
from .models.bigbird_pegasus import (
BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST,
BigBirdPegasusForCausalLM,
BigBirdPegasusForConditionalGeneration,
BigBirdPegasusForQuestionAnswering,
BigBirdPegasusForSequenceClassification,
BigBirdPegasusModel,
)
from .models.blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForCausalLM,
......
......@@ -635,9 +635,17 @@ class PegasusConverter(SpmConverter):
vocab = [
(self.original_tokenizer.pad_token, 0.0),
(self.original_tokenizer.eos_token, 0.0),
(self.original_tokenizer.mask_token_sent, 0.0),
(self.original_tokenizer.mask_token, 0.0),
]
if self.original_tokenizer.mask_token_sent is not None:
vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]
if (
self.original_tokenizer.mask_token is not None
and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
):
vocab += [(self.original_tokenizer.mask_token, 0.0)]
vocab += [(f"<unk_{i}>", -100.0) for i in range(2, self.original_tokenizer.offset)]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
return vocab
......
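# For reference, a hedged sketch of the vocab head that the branches above produce with the default Pegasus
# settings (pad/eos first, the two mask tokens only when configured, then the pretraining filler tokens):
#
#     [("<pad>", 0.0), ("</s>", 0.0)]                          # ids 0 and 1
#     + [("<mask_1>", 0.0), ("<mask_2>", 0.0)]                 # ids 2 and 3, only when mask_token_sent / mask_token are set
#     + [(f"<unk_{i}>", -100.0) for i in range(2, offset)]     # filler entries up to the pretraining offset
#     + the real SentencePiece pieces, starting at id offset + 2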
......@@ -26,6 +26,7 @@ from . import (
bert_japanese,
bertweet,
big_bird,
bigbird_pegasus,
blenderbot,
blenderbot_small,
camembert,
......
......@@ -23,6 +23,10 @@ from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartCo
from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from ..bert_generation.configuration_bert_generation import BertGenerationConfig
from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
from ..bigbird_pegasus.configuration_bigbird_pegasus import (
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
BigBirdPegasusConfig,
)
from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
from ..blenderbot_small.configuration_blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -86,6 +90,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
# Add archive maps here
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -139,6 +144,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
CONFIG_MAPPING = OrderedDict(
[
# Add configs here
("bigbird_pegasus", BigBirdPegasusConfig),
("deit", DeiTConfig),
("luke", LukeConfig),
("gpt_neo", GPTNeoConfig),
......@@ -198,6 +204,7 @@ CONFIG_MAPPING = OrderedDict(
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("bigbird_pegasus", "BigBirdPegasus"),
("deit", "DeiT"),
("luke", "LUKE"),
("gpt_neo", "GPT Neo"),
......
......@@ -59,6 +59,13 @@ from ..big_bird.modeling_big_bird import (
BigBirdForTokenClassification,
BigBirdModel,
)
from ..bigbird_pegasus.modeling_bigbird_pegasus import (
BigBirdPegasusForCausalLM,
BigBirdPegasusForConditionalGeneration,
BigBirdPegasusForQuestionAnswering,
BigBirdPegasusForSequenceClassification,
BigBirdPegasusModel,
)
from ..blenderbot.modeling_blenderbot import BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel
from ..blenderbot_small.modeling_blenderbot_small import (
BlenderbotSmallForCausalLM,
......@@ -288,6 +295,7 @@ from .configuration_auto import (
BertConfig,
BertGenerationConfig,
BigBirdConfig,
BigBirdPegasusConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
......@@ -344,6 +352,7 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(BigBirdPegasusConfig, BigBirdPegasusModel),
(DeiTConfig, DeiTModel),
(LukeConfig, LukeModel),
(GPTNeoConfig, GPTNeoModel),
......@@ -439,6 +448,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
# Model with LM heads mapping
(BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration),
(GPTNeoConfig, GPTNeoForCausalLM),
(BigBirdConfig, BigBirdForMaskedLM),
(Speech2TextConfig, Speech2TextForConditionalGeneration),
......@@ -485,6 +495,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Causal LM mapping
(BigBirdPegasusConfig, BigBirdPegasusForCausalLM),
(GPTNeoConfig, GPTNeoForCausalLM),
(BigBirdConfig, BigBirdForCausalLM),
(CamembertConfig, CamembertForCausalLM),
......@@ -557,6 +568,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration),
(M2M100Config, M2M100ForConditionalGeneration),
(LEDConfig, LEDForConditionalGeneration),
(BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
......@@ -577,6 +589,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Sequence Classification mapping
(BigBirdPegasusConfig, BigBirdPegasusForSequenceClassification),
(BigBirdConfig, BigBirdForSequenceClassification),
(ConvBertConfig, ConvBertForSequenceClassification),
(LEDConfig, LEDForSequenceClassification),
......@@ -614,6 +627,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
# Model for Question Answering mapping
(BigBirdPegasusConfig, BigBirdPegasusForQuestionAnswering),
(BigBirdConfig, BigBirdForQuestionAnswering),
(ConvBertConfig, ConvBertForQuestionAnswering),
(LEDConfig, LEDForQuestionAnswering),
......
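# Registering BigBirdPegasus in the auto-mappings above makes it reachable through the generic Auto classes.
# A hedged usage sketch (the checkpoint name is taken from the archive map added later in this commit):
from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained("google/bigbird-pegasus-large-pubmed")            # resolves to BigBirdPegasusConfig
model = AutoModelForSeq2SeqLM.from_pretrained("google/bigbird-pegasus-large-pubmed")  # BigBirdPegasusForConditionalGeneration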
......@@ -549,6 +549,7 @@ class BigBirdBlockSparseAttention(nn.Module):
rsqrt_d = 1 / math.sqrt(attention_head_size)
bsz = batch_size
attn_mask_penalty = -10000.0
# generate random attention and corresponding masks
np.random.seed(seed)
......@@ -606,7 +607,7 @@ class BigBirdBlockSparseAttention(nn.Module):
first_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 0], key_layer, ndim=4)
first_product = first_product * rsqrt_d
first_product += (1.0 - to_mask) * -10000.0
first_product += (1.0 - to_mask) * attn_mask_penalty
first_attn_weights = F.softmax(first_product, dim=-1) # [bsz, n_heads, from_block_size, to_seq_len]
# [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1]
......@@ -658,7 +659,7 @@ class BigBirdBlockSparseAttention(nn.Module):
dim=3,
)
second_product = second_product * rsqrt_d
second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * -10000.0
second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty
second_attn_weights = F.softmax(
second_product, dim=-1
) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size]
......@@ -709,10 +710,10 @@ class BigBirdBlockSparseAttention(nn.Module):
last_band_product = last_band_product * rsqrt_d
# masking padded tokens
inner_band_product += (1.0 - band_mask) * -10000.0
first_band_product += (1.0 - to_mask[:, :, :, :to_block_size].unsqueeze(3)) * -10000.0
last_band_product += (1.0 - to_mask[:, :, :, -to_block_size:].unsqueeze(3)) * -10000.0
rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * -10000.0
inner_band_product += (1.0 - band_mask) * attn_mask_penalty
first_band_product += (1.0 - to_mask[:, :, :, :to_block_size].unsqueeze(3)) * attn_mask_penalty
last_band_product += (1.0 - to_mask[:, :, :, -to_block_size:].unsqueeze(3)) * attn_mask_penalty
rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * attn_mask_penalty
# completing attention scores matrix for all q[-2:2]
band_product = torch.cat(
......@@ -792,7 +793,7 @@ class BigBirdBlockSparseAttention(nn.Module):
dim=3,
)
second_last_product = second_last_product * rsqrt_d
second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0
second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty
second_last_attn_weights = F.softmax(
second_last_product, dim=-1
) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size]
......@@ -808,7 +809,7 @@ class BigBirdBlockSparseAttention(nn.Module):
# [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len]
last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4)
last_product = last_product * rsqrt_d
last_product += (1.0 - to_mask) * -10000.0
last_product += (1.0 - to_mask) * attn_mask_penalty
last_attn_weights = F.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n]
# [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1]
......
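# The hunks above only name the constant; the underlying pattern is the usual additive attention mask,
# roughly as in this standalone sketch (not code from this PR):
import torch
import torch.nn.functional as F

attn_mask_penalty = -10000.0
scores = torch.randn(1, 2, 4, 4)                        # [bsz, n_heads, from_len, to_len]
to_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])            # 1.0 for real tokens, 0.0 for padding
scores = scores + (1.0 - to_mask) * attn_mask_penalty   # padded keys get a large negative score
attn_probs = F.softmax(scores, dim=-1)                  # ... and therefore ~zero attention probability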
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_torch_available
_import_structure = {
"configuration_bigbird_pegasus": ["BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdPegasusConfig"],
}
if is_torch_available():
_import_structure["modeling_bigbird_pegasus"] = [
"BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST",
"BigBirdPegasusForCausalLM",
"BigBirdPegasusForConditionalGeneration",
"BigBirdPegasusForQuestionAnswering",
"BigBirdPegasusForSequenceClassification",
"BigBirdPegasusModel",
"BigBirdPegasusPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_bigbird_pegasus import BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdPegasusConfig
if is_torch_available():
from .modeling_bigbird_pegasus import (
BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST,
BigBirdPegasusForCausalLM,
BigBirdPegasusForConditionalGeneration,
BigBirdPegasusForQuestionAnswering,
BigBirdPegasusForSequenceClassification,
BigBirdPegasusModel,
BigBirdPegasusPreTrainedModel,
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
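# The _LazyModule indirection keeps `import transformers` cheap: submodules are only imported when one of
# the names declared in _import_structure is actually accessed. A hedged sketch of the resulting behaviour:
from transformers.models import bigbird_pegasus

# No torch-backed module has been imported yet; this attribute access triggers
# _get_module("configuration_bigbird_pegasus") behind the scenes.
config_cls = bigbird_pegasus.BigBirdPegasusConfig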
# coding=utf-8
# Copyright Google Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BigBirdPegasus model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/bigbird-pegasus-large-arxiv": "https://huggingface.co/google/bigbird-pegasus-large-arxiv/resolve/main/config.json",
"google/bigbird-pegasus-large-pubmed": "https://huggingface.co/google/bigbird-pegasus-large-pubmed/resolve/main/config.json",
"google/bigbird-pegasus-large-bigpatent": "https://huggingface.co/google/bigbird-pegasus-large-bigpatent/resolve/main/config.json",
# See all BigBirdPegasus models at https://huggingface.co/models?filter=bigbird_pegasus
}
class BigBirdPegasusConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.BigBirdPegasusModel`. It is
used to instantiate a BigBirdPegasus model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the BigBirdPegasus
`google/bigbird-pegasus-large-arxiv <https://huggingface.co/google/bigbird-pegasus-large-arxiv>`__ architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 96103):
Vocabulary size of the BigBirdPegasus model. Defines the number of different tokens that can be represented
by the :obj:`input_ids` passed when calling :class:`~transformers.BigBirdPegasusModel`.
d_model (:obj:`int`, `optional`, defaults to 1024):
Dimension of the layers and the pooler layer.
encoder_layers (:obj:`int`, `optional`, defaults to 16):
Number of encoder layers.
decoder_layers (:obj:`int`, `optional`, defaults to 16):
Number of decoder layers.
encoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_fast"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"`, :obj:`"gelu_fast"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for the classifier.
max_position_embeddings (:obj:`int`, `optional`, defaults to 4096):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 1024 or 2048 or 4096).
init_std (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the encoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
decoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
attention_type (:obj:`str`, `optional`, defaults to :obj:`"block_sparse"`):
Whether to use the block sparse attention (with O(n) complexity) introduced in the paper, or the original
full attention (with O(n^2) complexity) in the encoder. Possible values are :obj:`"original_full"` and
:obj:`"block_sparse"`.
use_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use bias in the query, key and value projections.
block_size (:obj:`int`, `optional`, defaults to 64):
Size of each block. Useful only when :obj:`attention_type == "block_sparse"`.
num_random_blocks (:obj:`int`, `optional`, defaults to 3):
Each query attends to this many random blocks. Useful only when :obj:`attention_type ==
"block_sparse"`.
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to rescale embeddings by (hidden_size ** 0.5).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
Example::
>>> from transformers import BigBirdPegasusModel, BigBirdPegasusConfig
>>> # Initializing a BigBirdPegasus google/bigbird-pegasus-large-arxiv style configuration
>>> configuration = BigBirdPegasusConfig()
>>> # Initializing a model from the google/bigbird-pegasus-large-arxiv style configuration
>>> model = BigBirdPegasusModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "bigbird_pegasus"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=96103,
max_position_embeddings=4096,
encoder_layers=16,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=16,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu_fast",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
decoder_start_token_id=2,
classifier_dropout=0.0,
scale_embedding=True,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
attention_type="block_sparse", # only for encoder
block_size=64,
num_random_blocks=3,
use_bias=False,
**kwargs
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
# extra config
self.attention_type = attention_type
self.block_size = block_size
self.num_random_blocks = num_random_blocks
self.use_bias = use_bias
@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
@property
def hidden_size(self) -> int:
return self.d_model
@property
def attention_probs_dropout_prob(self) -> float:
return self.attention_dropout
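# The properties above alias the BERT-style attribute names onto the BART-style configuration fields,
# so generic code can query either; a small sketch using the defaults documented above:
from transformers import BigBirdPegasusConfig

config = BigBirdPegasusConfig()
assert config.hidden_size == config.d_model == 1024
assert config.num_attention_heads == config.encoder_attention_heads == 16
assert config.attention_probs_dropout_prob == config.attention_dropout == 0.0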
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Dict
import tensorflow as tf
import torch
from tqdm import tqdm
from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration
INIT_COMMON = [
# tf -> hf
("/", "."),
("layer_", "layers."),
("kernel", "weight"),
("beta", "bias"),
("gamma", "weight"),
("pegasus", "model"),
]
END_COMMON = [
(".output.dense", ".fc2"),
("intermediate.LayerNorm", "final_layer_norm"),
("intermediate.dense", "fc1"),
]
DECODER_PATTERNS = (
INIT_COMMON
+ [
("attention.self.LayerNorm", "self_attn_layer_norm"),
("attention.output.dense", "self_attn.out_proj"),
("attention.self", "self_attn"),
("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
("attention.encdec_output.dense", "encoder_attn.out_proj"),
("attention.encdec", "encoder_attn"),
("key", "k_proj"),
("value", "v_proj"),
("query", "q_proj"),
("decoder.LayerNorm", "decoder.layernorm_embedding"),
]
+ END_COMMON
)
REMAINING_PATTERNS = (
INIT_COMMON
+ [
("embeddings.word_embeddings", "shared.weight"),
("embeddings.position_embeddings", "embed_positions.weight"),
("attention.self.LayerNorm", "self_attn_layer_norm"),
("attention.output.dense", "self_attn.output"),
("attention.self", "self_attn.self"),
("encoder.LayerNorm", "encoder.layernorm_embedding"),
]
+ END_COMMON
)
KEYS_TO_IGNORE = [
"encdec/key/bias",
"encdec/query/bias",
"encdec/value/bias",
"self/key/bias",
"self/query/bias",
"self/value/bias",
"encdec_output/dense/bias",
"attention/output/dense/bias",
]
def rename_state_dict_key(k, patterns):
for tf_name, hf_name in patterns:
k = k.replace(tf_name, hf_name)
return k
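# A hedged illustration of how the patterns compose; the TF-style key below is made up for the example
# and is not taken from a real checkpoint:
#
#     "pegasus/decoder/layer_0/attention/self/query/kernel"
#       -> "model.decoder.layers.0.attention.self.query.weight"   (INIT_COMMON replacements)
#       -> "model.decoder.layers.0.self_attn.query.weight"        ("attention.self" -> "self_attn")
#       -> "model.decoder.layers.0.self_attn.q_proj.weight"       ("query" -> "q_proj")
#
# i.e. rename_state_dict_key("pegasus/decoder/layer_0/attention/self/query/kernel", DECODER_PATTERNS)
# would return "model.decoder.layers.0.self_attn.q_proj.weight".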
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
cfg = BigBirdPegasusConfig(**config_update)
torch_model = BigBirdPegasusForConditionalGeneration(cfg)
state_dict = torch_model.state_dict()
mapping = {}
# separating decoder weights
decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
if any(conditions):
continue
patterns = DECODER_PATTERNS
new_k = rename_state_dict_key(k, patterns)
if new_k not in state_dict:
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if any(i in k for i in ["dense", "query", "key", "value"]):
v = v.T
mapping[new_k] = torch.from_numpy(v)
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
if any(conditions):
continue
patterns = REMAINING_PATTERNS
new_k = rename_state_dict_key(k, patterns)
if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if any(i in k for i in ["dense", "query", "key", "value"]):
v = v.T
mapping[new_k] = torch.from_numpy(v)
if k != "pegasus/embeddings/position_embeddings":
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
missing, extra = torch_model.load_state_dict(mapping, strict=False)
unexpected_missing = [
k
for k in missing
if k
not in [
"final_logits_bias",
"model.encoder.embed_tokens.weight",
"model.decoder.embed_tokens.weight",
"lm_head.weight",
]
]
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
assert extra == [], f"no matches found for the following tf keys {extra}"
return torch_model
def get_tf_weights_as_numpy(path) -> Dict:
init_vars = tf.train.list_variables(path)
tf_weights = {}
ignore_name = ["global_step"]
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
skip_key = any([pat in name for pat in ignore_name])
if skip_key:
continue
array = tf.train.load_variable(path, name)
tf_weights[name] = array
return tf_weights
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
tf_weights = get_tf_weights_as_numpy(ckpt_path)
torch_model = convert_bigbird_pegasus(tf_weights, config_update)
torch_model.save_pretrained(save_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
config_update = {}
convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)
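# A hedged usage sketch for the converter above; the paths are placeholders, not from this PR, and an
# empty config_update keeps the BigBirdPegasusConfig defaults:
#
#     convert_bigbird_pegasus_ckpt_to_pytorch(
#         ckpt_path="/path/to/tf_checkpoint",      # passed to tf.train.list_variables
#         save_dir="./bigbird-pegasus-pytorch",    # forwarded to save_pretrained
#         config_update={},
#     )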
......@@ -80,7 +80,6 @@ class PegasusTokenizer(PreTrainedTokenizer):
"""
vocab_files_names = VOCAB_FILES_NAMES
offset = 103 # entries 2 - 104 are only used for pretraining
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
......@@ -95,8 +94,11 @@ class PegasusTokenizer(PreTrainedTokenizer):
mask_token="<mask_2>",
mask_token_sent="<mask_1>",
additional_special_tokens=None,
offset=103, # entries 2 - 104 are only used for pretraining
**kwargs
):
self.offset = offset
if additional_special_tokens is not None:
assert isinstance(
additional_special_tokens, list
......@@ -104,7 +106,7 @@ class PegasusTokenizer(PreTrainedTokenizer):
additional_special_tokens_extended = (
([mask_token_sent] + additional_special_tokens)
if mask_token_sent not in additional_special_tokens
if mask_token_sent not in additional_special_tokens and mask_token_sent is not None
else additional_special_tokens
)
# fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
......@@ -118,7 +120,7 @@ class PegasusTokenizer(PreTrainedTokenizer):
)
additional_special_tokens = additional_special_tokens_extended
else:
additional_special_tokens = [mask_token_sent]
additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
additional_special_tokens += [f"<unk_{i}>" for i in range(2, self.offset)]
super().__init__(
......@@ -127,24 +129,34 @@ class PegasusTokenizer(PreTrainedTokenizer):
mask_token=mask_token,
pad_token=pad_token,
mask_token_sent=mask_token_sent,
offset=offset,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
self.mask_token_sent = mask_token_sent
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
self.mask_token_sent = mask_token_sent
# add special tokens to encoder dict
self.encoder: Dict[int, str] = {
0: self.pad_token,
1: self.eos_token,
2: self.mask_token_sent,
3: self.mask_token,
}
# entries 2-104 are only used for pretraining and called <mask_1>, <mask_2>, unk_2, ...unk_102
# mask_token_sent is already added to list -> so start at 1
self.encoder.update({i + 3: additional_special_tokens[i] for i in range(1, self.offset - 1)})
if self.mask_token_sent is not None:
self.encoder.update(
{
2: self.mask_token_sent,
3: self.mask_token,
}
)
if self.offset > 0:
# entries 2-104 are only used for pretraining and called <mask_1>, <mask_2>, unk_2, ...unk_102
# mask_token_sent is already added to list -> so start at 1
self.encoder.update({i + 3: additional_special_tokens[i] for i in range(1, self.offset - 1)})
self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()}
@property
......@@ -206,10 +218,6 @@ class PegasusTokenizer(PreTrainedTokenizer):
all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
assert all_special_ids == set(
range(len(self.additional_special_tokens) + 3)
), f"There should be 3 special tokens: mask_token, pad_token, and eos_token + {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"
return [1 if x in all_special_ids else 0 for x in seq]
def get_special_tokens_mask(
......
......@@ -90,7 +90,6 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
<https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66>`__
that uses the tokens 2 - 104 only for pretraining
"""
offset = 103 # entries 2-104 are only used for pretraining
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
......@@ -107,8 +106,11 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
mask_token="<mask_2>",
mask_token_sent="<mask_1>",
additional_special_tokens=None,
offset=103, # entries 2 - 104 are only used for pretraining
**kwargs
):
self.offset = offset
if additional_special_tokens is not None:
assert isinstance(
additional_special_tokens, list
......@@ -116,7 +118,7 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
additional_special_tokens_extended = (
([mask_token_sent] + additional_special_tokens)
if mask_token_sent not in additional_special_tokens
if mask_token_sent not in additional_special_tokens and mask_token_sent is not None
else additional_special_tokens
)
# fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
......@@ -130,7 +132,7 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
)
additional_special_tokens = additional_special_tokens_extended
else:
additional_special_tokens = [mask_token_sent]
additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
additional_special_tokens += [f"<unk_{i}>" for i in range(2, self.offset)]
super().__init__(
......@@ -141,10 +143,10 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
unk_token=unk_token,
mask_token=mask_token,
mask_token_sent=mask_token_sent,
offset=offset,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
self.vocab_file = vocab_file
def _special_token_mask(self, seq):
......
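# The new offset / optional mask_token_sent arguments are what allow the plain Pegasus tokenizers to serve
# BigBirdPegasus checkpoints (see the docs hunk earlier in this commit). A hedged sketch; the actual values
# come from each checkpoint's tokenizer_config, which is not part of this diff:
from transformers import PegasusTokenizer, PegasusTokenizerFast

slow = PegasusTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv")
fast = PegasusTokenizerFast.from_pretrained("google/bigbird-pegasus-large-arxiv")
print(slow.offset, slow.mask_token_sent)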
......@@ -721,6 +721,50 @@ def load_tf_weights_in_big_bird(*args, **kwargs):
requires_backends(load_tf_weights_in_big_bird, ["torch"])
BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST = None
class BigBirdPegasusForCausalLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BigBirdPegasusForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BigBirdPegasusForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BigBirdPegasusForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BigBirdPegasusModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["torch"])
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -6,6 +6,7 @@ from collections import OrderedDict
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("BigBirdPegasusConfig", "BigBirdPegasusForQuestionAnswering"),
("BigBirdConfig", "BigBirdForQuestionAnswering"),
("ConvBertConfig", "ConvBertForQuestionAnswering"),
("LEDConfig", "LEDForQuestionAnswering"),
......
......@@ -310,19 +310,18 @@ class GenerationTesterMixin:
logits_processor.append(InfNanRemoveLogitsProcessor())
with torch.no_grad():
with torch.no_grad():
output_sample = model.sample(
input_ids_clone,
attention_mask=attention_mask_clone,
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
output_sample = model.sample(
input_ids_clone,
attention_mask=attention_mask_clone,
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
return output_sample, output_generate
def _beam_search_generate(
......
This diff is collapsed.
......@@ -1043,7 +1043,6 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as temp_dir_name:
model.base_model.save_pretrained(temp_dir_name)
model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)
with self.subTest(msg=f"Missing keys for {model.__class__.__name__}"):
self.assertGreater(len(loading_info["missing_keys"]), 0)
......