"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "361620954acf16b27727d763a591257b03f90b5d"
Unverified commit 925f34bb, authored by Patrick von Platen, committed by GitHub

Add "tie_word_embeddings" config param (#6692)

* add tie_word_embeddings

* correct word embeddings in modeling utils

* make style

* make config param only relevant for torch

* make style

* correct typo

* delete deprecated arg in transfo-xl
parent fa8ee8e8
@@ -165,9 +165,16 @@ class ReformerConfig(PretrainedConfig):
         num_hashes=1,
         pad_token_id=0,
         vocab_size=320,
+        tie_word_embeddings=False,
         **kwargs
     ):
-        super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_decoder=is_decoder, **kwargs)
+        super().__init__(
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            is_decoder=is_decoder,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
 
         self.hash_seed = hash_seed
         self.vocab_size = vocab_size
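The hunk above adds the parameter to ReformerConfig and forwards it to PretrainedConfig. A minimal usage sketch (assuming transformers at this commit; purely config-level, nothing is downloaded):

from transformers import ReformerConfig

# Reformer keeps its historical behaviour of not tying word embeddings by
# defaulting the new parameter to False.
config = ReformerConfig()
print(config.tie_word_embeddings)  # False

# The default can still be overridden explicitly.
config = ReformerConfig(tie_word_embeddings=True)
print(config.tie_word_embeddings)  # True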
@@ -17,6 +17,7 @@
 import logging
+import warnings
 
 from .configuration_utils import PretrainedConfig
@@ -79,8 +80,6 @@ class TransfoXLConfig(PretrainedConfig):
             number of samples in sampled softmax
         adaptive (:obj:`boolean`, optional, defaults to :obj:`True`):
             use adaptive softmax
-        tie_weight (:obj:`boolean`, optional, defaults to :obj:`True`):
-            tie the word embedding and softmax weights
         dropout (:obj:`float`, optional, defaults to 0.1):
             The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
         dropatt (:obj:`float`, optional, defaults to 0):
@@ -135,7 +134,6 @@ class TransfoXLConfig(PretrainedConfig):
         attn_type=0,
         sample_softmax=-1,
         adaptive=True,
-        tie_weight=True,
         dropout=0.1,
         dropatt=0.0,
         untie_r=True,
@@ -147,12 +145,17 @@ class TransfoXLConfig(PretrainedConfig):
         eos_token_id=0,
         **kwargs
     ):
-        super().__init__(eos_token_id=eos_token_id, **kwargs)
+        if "tie_weight" in kwargs:
+            warnings.warn(
+                "The config parameter `tie_weight` is deprecated. Please use `tie_word_embeddings` instead.",
+                FutureWarning,
+            )
+            kwargs["tie_word_embeddings"] = kwargs["tie_weight"]
+        super().__init__(eos_token_id=eos_token_id, **kwargs)
 
         self.vocab_size = vocab_size
         self.cutoffs = []
         self.cutoffs.extend(cutoffs)
-        self.tie_weight = tie_weight
         if proj_share_all_but_first:
             self.tie_projs = [False] + [True] * len(self.cutoffs)
         else:
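The constructor above keeps accepting the old `tie_weight` keyword through **kwargs, warns, and remaps it onto the new name. A small sketch of that deprecation path (assuming transformers at this commit):

import warnings

from transformers import TransfoXLConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config = TransfoXLConfig(tie_weight=False)  # deprecated spelling

print(config.tie_word_embeddings)  # False, copied from the deprecated kwarg
print(any(issubclass(w.category, FutureWarning) for w in caught))  # True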
@@ -134,6 +134,7 @@ class PretrainedConfig(object):
         PyTorch specific parameters
             - **torchscript** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should be
               used with Torchscript.
+            - **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has a output word embedding layer.
 
         TensorFlow specific parameters
             - **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should
@@ -150,6 +151,9 @@ class PretrainedConfig(object):
         self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
         self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
         self.pruned_heads = kwargs.pop("pruned_heads", {})
+        self.tie_word_embeddings = kwargs.pop(
+            "tie_word_embeddings", True
+        )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
 
         # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
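Because the attribute now lives on PretrainedConfig itself, every configuration exposes it, with True as the default. A quick sketch (assuming transformers at this commit):

from transformers import PretrainedConfig

config = PretrainedConfig()
print(config.tie_word_embeddings)  # True by default

config = PretrainedConfig(tie_word_embeddings=False)
print(config.tie_word_embeddings)  # False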
@@ -647,14 +647,13 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
         self.sop_classifier = AlbertSOPHead(config)
 
         self.init_weights()
-        self.tie_weights()
-
-    def tie_weights(self):
-        self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
 
     def get_output_embeddings(self):
         return self.predictions.decoder
 
+    def get_input_embeddings(self):
+        return self.albert.embeddings.word_embeddings
+
     @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -798,14 +797,13 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
         self.predictions = AlbertMLMHead(config)
 
         self.init_weights()
-        self.tie_weights()
-
-    def tie_weights(self):
-        self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
 
     def get_output_embeddings(self):
         return self.predictions.decoder
 
+    def get_input_embeddings(self):
+        return self.albert.embeddings.word_embeddings
+
     @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
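With get_input_embeddings defined, the shared PreTrainedModel.tie_weights (last hunk of this commit) handles the tying during init_weights, so the per-model tie_weights overrides and the extra self.tie_weights() calls go away. A sketch with a made-up tiny AlbertConfig (the hyper-parameter values below are arbitrary, chosen only to keep the model small; no pretrained weights are involved):

from transformers import AlbertConfig, AlbertForMaskedLM

config = AlbertConfig(
    vocab_size=64,
    embedding_size=16,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
)
model = AlbertForMaskedLM(config)

# tie_word_embeddings defaults to True, so the MLM decoder and the word
# embedding end up sharing the same weight tensor.
print(model.get_output_embeddings().weight is model.get_input_embeddings().weight)  # True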
@@ -945,7 +945,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
         self.cls.predictions.dense = resized_dense
         self.cls.predictions.dense.to(self.device)
 
-        if output_embeddings is not None:
+        if output_embeddings is not None and self.config.tie_word_embeddings:
             self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
 
     @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
@@ -1060,7 +1060,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
         self.cls.predictions.dense = resized_dense
         self.cls.predictions.dense.to(self.device)
 
-        if output_embeddings is not None:
+        if output_embeddings is not None and self.config.tie_word_embeddings:
             self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
 
     @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@@ -2155,10 +2155,6 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
 
-    def tie_weights(self):
-        # word embeddings are not tied in Reformer
-        pass
-
     @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -2274,10 +2270,6 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
 
-    def tie_weights(self):
-        # word embeddings are not tied in Reformer
-        pass
-
     @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -2356,10 +2348,6 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
 
         self.init_weights()
 
-    def tie_weights(self):
-        # word embeddings are not tied in Reformer
-        pass
-
     @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -2459,10 +2447,6 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
 
         self.init_weights()
 
-    def tie_weights(self):
-        # word embeddings are not tied in Reformer
-        pass
-
     @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
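The deleted overrides existed only to prevent tying; with the guard moved into the base class and ReformerConfig defaulting tie_word_embeddings to False, they are redundant. A standalone PyTorch sketch of the pattern (an illustration only, not the transformers implementation):

import torch.nn as nn

class TinyLM(nn.Module):
    """Toy model where a single flag decides whether embeddings are tied."""

    def __init__(self, vocab_size=32, hidden_size=8, tie_word_embeddings=False):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if tie_word_embeddings:
            # Input and output embeddings share the same tensor object.
            self.lm_head.weight = self.embed.weight

untied = TinyLM(tie_word_embeddings=False)
print(untied.lm_head.weight is untied.embed.weight)  # False, as for Reformer

tied = TinyLM(tie_word_embeddings=True)
print(tied.lm_head.weight is tied.embed.weight)  # True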
@@ -62,7 +62,7 @@ def build_tf_to_pytorch_map(model, config):
             zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs)
         ):
             layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
-            if config.tie_weight:
+            if config.tie_word_embeddings:
                 tf_to_pt_map.update({layer_str + "b": out_l.bias})
             else:
                 raise NotImplementedError
@@ -978,7 +978,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         Run this to be sure output and input (adaptive) softmax weights are tied
         """
 
-        if self.config.tie_weight:
+        if self.config.tie_word_embeddings:
             for i in range(len(self.crit.out_layers)):
                 self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
 
         if self.config.tie_projs:
@@ -413,7 +413,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         the weights instead.
         """
         output_embeddings = self.get_output_embeddings()
-        if output_embeddings is not None:
+        if output_embeddings is not None and self.config.tie_word_embeddings:
             self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
 
         if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
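With the guard in this single code path, the flag now controls tying for every PyTorch model. Continuing the tiny Albert example from above (same arbitrary hyper-parameters), opting out keeps the two matrices separate:

from transformers import AlbertConfig, AlbertForMaskedLM

config = AlbertConfig(
    vocab_size=64,
    embedding_size=16,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    tie_word_embeddings=False,  # opt out of tying
)
model = AlbertForMaskedLM(config)

print(model.get_output_embeddings().weight is model.get_input_embeddings().weight)  # False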