Unverified Commit 2a6fbe6a authored by Patrick von Platen, committed by GitHub

[XLNet] Fix mems behavior (#8567)

* fix mems in xlnet

* fix use_mems

* fix use_mem_len

* fix use mems

* clean docs

* fix tf typo

* make xlnet tf for generation work

* fix tf test

* refactor use cache

* add use cache for missing models

* correct use_cache in generate

* correct use cache in tf generate

* fix tf

* correct getattr typo

* make sylvain happy

* change in docs as well

* do not apply to cookie cutter statements

* fix tf test

* make pytorch model fully backward compatible
parent 369f1d77
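
Before the diff: a usage sketch (not part of the commit) of the mems behavior the message above refers to, assuming the `use_mems` forward argument and the `mems` output field that this PR introduces for XLNet; exact argument names may differ between releases.

```python
# Hypothetical illustration of XLNet's memory cache after this PR.
# Assumes the `use_mems` forward argument and `mems` output field exist as described.
import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")

inputs = tokenizer("XLNet caches hidden states as memory for the next segment", return_tensors="pt")

with torch.no_grad():
    # With mems enabled, the cached hidden states are returned and can be fed
    # back via the `mems` argument when processing the following segment.
    with_mems = model(**inputs, use_mems=True, return_dict=True)
    # With mems disabled, no cache is returned.
    without_mems = model(**inputs, use_mems=False, return_dict=True)

print(with_mems.mems is not None)   # expected: True
print(without_mems.mems is None)    # expected: True
```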
@@ -809,7 +809,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
 @add_start_docstrings(
     """
-    Albert Model with two heads on top for pre-training: a `masked language modeling` head and a `sentence order
+    Albert Model with two heads on top for pretraining: a `masked language modeling` head and a `sentence order
     prediction` (classification) head.
     """,
     ALBERT_START_DOCSTRING,
...
@@ -108,6 +108,8 @@ class BartConfig(PretrainedConfig):
         force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only
             :obj:`True` for `bart-large-cnn`.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
     """
     model_type = "bart"
     keys_to_ignore_at_inference = ["past_key_values"]

@@ -134,9 +136,6 @@ class BartConfig(PretrainedConfig):
         classifier_dropout=0.0,
         num_labels=3,
         is_encoder_decoder=True,
-        pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
         normalize_before=False,
         add_final_layer_norm=False,
         do_blenderbot_90_layernorm=False,

@@ -145,6 +144,10 @@ class BartConfig(PretrainedConfig):
         static_position_embeddings=False,
         add_bias_logits=False,
         force_bos_token_to_be_generated=False,
+        use_cache=True,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
         **common_kwargs
     ):
         r"""

@@ -208,6 +211,8 @@ class BartConfig(PretrainedConfig):
         self.do_blenderbot_90_layernorm = do_blenderbot_90_layernorm

+        self.use_cache = use_cache
+
     @property
     def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
...
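
The `use_cache` entry documented in the hunk above is a plain configuration flag; a minimal sketch of setting it on `BartConfig` (a standalone config built only for illustration, not tied to any released checkpoint):

```python
# Minimal sketch: `use_cache` defaults to True and can be overridden at construction time.
from transformers import BartConfig

config = BartConfig()
print(config.use_cache)        # True (new default added by this commit)
print(config.pad_token_id, config.bos_token_id, config.eos_token_id)  # 1 0 2

no_cache_config = BartConfig(use_cache=False)
print(no_cache_config.use_cache)  # False: models built from this config return no past key/values by default
```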
@@ -888,7 +888,7 @@ class BertModel(BertPreTrainedModel):
 @add_start_docstrings(
     """
-    Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and a `next
+    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
     sentence prediction (classification)` head.
     """,
     BERT_START_DOCSTRING,
...
@@ -90,7 +90,7 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 class TFBertPreTrainingLoss:
     """
-    Loss function suitable for BERT-like pre-training, that is, the task of pretraining a language model by combining
+    Loss function suitable for BERT-like pretraining, that is, the task of pretraining a language model by combining
     NSP + MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss
     computation.
     """

@@ -878,7 +878,7 @@ class TFBertModel(TFBertPreTrainedModel):
 @add_start_docstrings(
     """
-    Bert Model with two heads on top as done during the pre-training:
+    Bert Model with two heads on top as done during the pretraining:
     a `masked language modeling` head and a `next sentence prediction (classification)` head.
     """,
     BERT_START_DOCSTRING,
...
@@ -80,7 +80,7 @@ class BertweetTokenizer(PreTrainedTokenizer):
         normalization (:obj:`bool`, `optional`, defaults to :obj:`False`)
             Whether or not to apply a normalization preprocess.
         bos_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
-            The beginning of sequence token that was used during pre-training. Can be used a sequence classifier token.
+            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.

             .. note::
...
@@ -61,6 +61,9 @@ class CTRLConfig(PretrainedConfig):
             The epsilon to use in the layer normalization layers
         initializer_range (:obj:`float`, `optional`, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+
     Examples::

@@ -98,6 +101,7 @@ class CTRLConfig(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
+        use_cache=True,
         **kwargs
     ):
         super().__init__(**kwargs)

@@ -119,6 +123,7 @@ class CTRLConfig(PretrainedConfig):
         self.summary_activation = summary_activation
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
+        self.use_cache = use_cache

     @property
     def max_position_embeddings(self):
...
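
As elsewhere, the config value is only a default; a `use_cache` argument passed to the forward call is expected to take precedence. A hedged sketch with a tiny, randomly initialized CTRL model (all sizes are illustrative, nothing is downloaded):

```python
# Sketch of the expected precedence: a per-call `use_cache` overrides `config.use_cache`.
import torch
from transformers import CTRLConfig, CTRLLMHeadModel

config = CTRLConfig(vocab_size=100, n_positions=64, n_embd=32, dff=64, n_layer=2, n_head=2, use_cache=False)
model = CTRLLMHeadModel(config)
input_ids = torch.tensor([[5, 6, 7]])

with torch.no_grad():
    default_out = model(input_ids, return_dict=True)                    # follows config: no cache
    override_out = model(input_ids, use_cache=True, return_dict=True)   # per-call override

print(default_out.past_key_values is None)       # expected: True
print(override_out.past_key_values is not None)  # expected: True
```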
@@ -772,7 +772,7 @@ DEBERTA_START_DOCSTRING = r"""
     The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
     <https://arxiv.org/abs/2006.03654>`_ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build on top of
     BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
-    improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pre-training data.
+    improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data.

     This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
     subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
...
@@ -891,8 +891,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
 @add_start_docstrings(
     """
-    Electra model with a binary classification head on top as used during pre-training for identifying generated
-    tokens.
+    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.

     It is recommended to load the discriminator checkpoint into that model.
     """,
...
@@ -789,8 +789,7 @@ class TFElectraModel(TFElectraPreTrainedModel):
 @add_start_docstrings(
     """
-    Electra model with a binary classification head on top as used during pre-training for identifying generated
-    tokens.
+    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.

     Even though both the discriminator and generator may be loaded into this model, the discriminator is the only model
     of the two to have the correct classification head to be used for this model.
...
@@ -109,6 +109,8 @@ class FSMTConfig(PretrainedConfig):
         early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`)
             Flag that will be used by default in the :obj:`generate` method of the model. Whether to stop the beam
             search when at least ``num_beams`` sentences are finished per batch or not.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).

     Examples::

@@ -142,9 +144,6 @@ class FSMTConfig(PretrainedConfig):
         dropout=0.1,
         activation_dropout=0.0,
         init_std=0.02,
-        pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
         decoder_start_token_id=2,
         is_encoder_decoder=True,
         scale_embedding=True,

@@ -152,6 +151,10 @@ class FSMTConfig(PretrainedConfig):
         num_beams=5,
         length_penalty=1.0,
         early_stopping=False,
+        use_cache=True,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
         **common_kwargs
     ):
         if "hidden_size" in common_kwargs:

@@ -196,6 +199,8 @@ class FSMTConfig(PretrainedConfig):
         self.activation_dropout = activation_dropout
         self.dropout = dropout

+        self.use_cache = use_cache
+
     @property
     def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
...
@@ -1241,7 +1241,7 @@ class TFFunnelModel(TFFunnelPreTrainedModel):
 @add_start_docstrings(
     """
-    Funnel model with a binary classification head on top as used during pre-training for identifying generated tokens.
+    Funnel model with a binary classification head on top as used during pretraining for identifying generated tokens.
     """,
     FUNNEL_START_DOCSTRING,
 )
...
@@ -104,6 +104,8 @@ class GPT2Config(PretrainedConfig):
             The dropout ratio to be used after the projection and activation.
         gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).

     Example::

@@ -142,9 +144,10 @@ class GPT2Config(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
+        gradient_checkpointing=False,
+        use_cache=True,
         bos_token_id=50256,
         eos_token_id=50256,
-        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

@@ -168,6 +171,7 @@ class GPT2Config(PretrainedConfig):
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
         self.gradient_checkpointing = gradient_checkpointing
+        self.use_cache = use_cache
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
...
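
A sketch of what the flag controls at run time, using a tiny, randomly initialized GPT-2 so no checkpoint is needed (all sizes are illustrative):

```python
# Sketch: with `use_cache` enabled the forward pass returns past key/value states
# that can be reused for incremental decoding; with it disabled, nothing is cached.
import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2, use_cache=True)
model = GPT2LMHeadModel(config)
input_ids = torch.tensor([[1, 2, 3, 4]])

with torch.no_grad():
    cached = model(input_ids, use_cache=True, return_dict=True)
    uncached = model(input_ids, use_cache=False, return_dict=True)

print(cached.past_key_values is not None)   # expected: True
print(uncached.past_key_values is None)     # expected: True
```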
@@ -1013,7 +1013,7 @@ class LxmertModel(LxmertPreTrainedModel):
 @add_start_docstrings(
-    """Lxmert Model with a specified pre-training head on top. """,
+    """Lxmert Model with a specified pretraining head on top. """,
     LXMERT_START_DOCSTRING,
 )
 class LxmertForPreTraining(LxmertPreTrainedModel):

@@ -1024,7 +1024,7 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
         self.num_qa_labels = config.num_qa_labels
         self.visual_loss_normalizer = config.visual_loss_normalizer

-        # Use of pre-training tasks
+        # Use of pretraining tasks
         self.task_mask_lm = config.task_mask_lm
         self.task_obj_predict = config.task_obj_predict
         self.task_matched = config.task_matched
...
@@ -1176,7 +1176,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
         self.num_qa_labels = config.num_qa_labels
         self.visual_loss_normalizer = config.visual_loss_normalizer

-        # Use of pre-training tasks
+        # Use of pretraining tasks
         self.task_mask_lm = config.task_mask_lm
         self.task_obj_predict = config.task_obj_predict
         self.task_matched = config.task_matched
...
@@ -933,7 +933,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
 @add_start_docstrings(
     """
-    MobileBert Model with two heads on top as done during the pre-training: a `masked language modeling` head and a
+    MobileBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
     `next sentence prediction (classification)` head.
     """,
     MOBILEBERT_START_DOCSTRING,
...
@@ -1014,7 +1014,7 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
 @add_start_docstrings(
     """
-    MobileBert Model with two heads on top as done during the pre-training: a `masked language modeling` head and a
+    MobileBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
     `next sentence prediction (classification)` head.
     """,
     MOBILEBERT_START_DOCSTRING,
...
@@ -96,6 +96,9 @@ class OpenAIGPTConfig(PretrainedConfig):
             :class:`~transformers.OpenAIGPTDoubleHeadsModel` and :class:`~transformers.OpenAIGPTDoubleHeadsModel`.
             The dropout ratio to be used after the projection and activation.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+
     Examples::

@@ -133,6 +136,7 @@ class OpenAIGPTConfig(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
+        use_cache=True,
         **kwargs
     ):
         super().__init__(**kwargs)

@@ -155,6 +159,7 @@ class OpenAIGPTConfig(PretrainedConfig):
         self.summary_activation = summary_activation
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
+        self.use_cache = use_cache

     @property
     def max_position_embeddings(self):
...
@@ -90,6 +90,8 @@ class ProphetNetConfig(PretrainedConfig):
         eps (:obj:`float`, `optional`, defaults to 0.0):
             Controls the ``epsilon`` parameter value for label smoothing in the loss calculation. If set to 0, no label
             smoothing is performed.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
     """
     model_type = "prophetnet"
     keys_to_ignore_at_inference = ["past_key_values"]

@@ -112,15 +114,16 @@ class ProphetNetConfig(PretrainedConfig):
         init_std=0.02,
         is_encoder_decoder=True,
         add_cross_attention=True,
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
         decoder_start_token_id=0,
         ngram=2,
         num_buckets=32,
         relative_max_distance=128,
         disable_ngram_loss=False,
         eps=0.0,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
         **kwargs
     ):
         super().__init__(

@@ -156,6 +159,8 @@ class ProphetNetConfig(PretrainedConfig):
         self.activation_dropout = activation_dropout
         self.dropout = dropout

+        self.use_cache = use_cache
+
     @property
     def num_attention_heads(self) -> int:
         return self.num_encoder_attention_heads
...
@@ -72,6 +72,8 @@ RAG_CONFIG_DOC = r"""
         output_retrieved(:obj:`bool`, `optional`, defaults to :obj:`False`):
             If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and
             :obj:`context_attention_mask` are returned. See returned tensors for more detail.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
 """

@@ -107,6 +109,7 @@ class RagConfig(PretrainedConfig):
         exclude_bos_score=False,
         do_marginalize=False,
         output_retrieved=False,
+        use_cache=True,
         **kwargs
     ):
         super().__init__(

@@ -156,6 +159,8 @@ class RagConfig(PretrainedConfig):
         self.do_deduplication = do_deduplication

+        self.use_cache = use_cache
+
     @classmethod
     def from_question_encoder_generator_configs(
         cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
...
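
Since `RagConfig` wraps a question-encoder config and a generator config, the flag can also be passed through the composition helper visible in the last hunk. A hedged sketch with default-initialized sub-configs chosen only for illustration:

```python
# Sketch: extra kwargs such as `use_cache` are forwarded to the composite RAG config.
from transformers import BartConfig, DPRConfig, RagConfig

question_encoder_config = DPRConfig()   # illustrative sub-configs, no weights involved
generator_config = BartConfig()

rag_config = RagConfig.from_question_encoder_generator_configs(
    question_encoder_config,
    generator_config,
    use_cache=True,
)
print(rag_config.use_cache)  # True
```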
@@ -138,6 +138,8 @@ class ReformerConfig(PretrainedConfig):
             :obj:`inputs_ids` passed when calling :class:`~transformers.ReformerModel`.
         tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to tie input and output embeddings.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).

     Examples::

@@ -188,6 +190,7 @@ class ReformerConfig(PretrainedConfig):
         pad_token_id=0,
         vocab_size=320,
         tie_word_embeddings=False,
+        use_cache=True,
         **kwargs
     ):
         super().__init__(

@@ -226,3 +229,4 @@ class ReformerConfig(PretrainedConfig):
         self.axial_norm_std = axial_norm_std
         self.chunk_size_lm_head = chunk_size_lm_head
         self.attn_layers = attn_layers
+        self.use_cache = use_cache
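
Reformer is one of the models behind the docstring's "not used by all models" caveat: its cache holds, roughly, bucket/hidden-state pairs rather than conventional key/value tensors. A config-only sketch, no weights involved:

```python
# Sketch: the flag lives on the config like any other default.
from transformers import ReformerConfig

config = ReformerConfig()          # use_cache defaults to True
print(config.use_cache)            # True
print(config.attn_layers)          # default mix of local and LSH attention layers

config = ReformerConfig(use_cache=False)
print(config.use_cache)            # False
```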