Unverified Commit 95f933ea authored by Patrick von Platen, committed by GitHub

[Pretrained Model] Add resize_position_embeddings (#13559)

* finish

* delete bogus file

* correct some stuff

* finish

* finish
parent c783e148
@@ -99,6 +99,13 @@ class ModelArguments:
"with private models)."
},
)
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
"the model's position embeddings."
},
)
@dataclass
@@ -366,6 +373,25 @@ def main():
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings < data_args.max_source_length
):
if model_args.resize_position_embeddings is None:
logger.warning(
f"Increasing the model's number of position embedding vectors from {model.config.max_position_embedding} "
f"to {data_args.max_source_length}."
)
model.resize_position_embeddings(data_args.max_source_length)
elif model_args.resize_position_embeddings:
model.resize_position_embeddings(data_args.max_source_length)
else:
raise ValueError(
f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}"
f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically "
"resize the model's position encodings by passing `--resize_position_embeddings`."
)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
# Preprocessing the datasets.
...
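For reference, the gist of the new flag in the hunk above: when `--max_source_length` exceeds the checkpoint's `max_position_embeddings`, the script resizes with a warning by default, resizes silently when the flag is passed explicitly, and errors out when the flag is explicitly set to false. A minimal standalone sketch of that decision, outside the script (the helper name is illustrative, not part of this commit):

def maybe_resize_position_embeddings(model, max_source_length, resize_position_embeddings=None):
    # Only act on models that expose a fixed-size position embedding table.
    max_pos = getattr(model.config, "max_position_embeddings", None)
    if max_pos is None or max_pos >= max_source_length:
        return
    if resize_position_embeddings is None:
        # Default behaviour: resize, but tell the user (the script uses logger.warning).
        print(f"Warning: growing position embeddings from {max_pos} to {max_source_length}.")
        model.resize_position_embeddings(max_source_length)
    elif resize_position_embeddings:
        model.resize_position_embeddings(max_source_length)
    else:
        raise ValueError(
            f"max_source_length={max_source_length} exceeds the model's {max_pos} position "
            "encodings; reduce it or pass --resize_position_embeddings."
        )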
@@ -887,6 +887,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return new_lm_head
def resize_position_embeddings(self, new_num_position_embeddings: int):
raise NotImplementedError(
f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
)
def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
raise NotImplementedError(
f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
)
def init_weights(self):
"""
If needed prunes and maybe initializes weights.
...
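The two base-class hooks above only define a contract; each architecture supplies its own implementation (DistilBERT and Pegasus below). For learned position embeddings the override usually boils down to copying the rows that still fit into a fresh table, roughly as in this self-contained sketch (the helper name is illustrative only):

import torch
import torch.nn as nn

def resize_learned_position_embeddings(old: nn.Embedding, new_num_positions: int) -> nn.Embedding:
    # Build a new table and keep whatever learned vectors still fit:
    # growing keeps all old rows, shrinking truncates from the end.
    if new_num_positions == old.num_embeddings:
        return old
    new = nn.Embedding(new_num_positions, old.embedding_dim)
    n = min(old.num_embeddings, new_num_positions)
    with torch.no_grad():
        new.weight[:n] = old.weight[:n]
    return new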
@@ -2833,7 +2833,7 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
return self.decoder(*args, **kwargs)
# Copied from transformers.models.pegasus.modeling_pegasus.PegasusForCausalLM with Pegasus->BigBirdPegasus, 'facebook/bart-large'->"google/bigbird-pegasus-large-arxiv"
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BigBirdPegasus, 'facebook/bart-large'->"google/bigbird-pegasus-large-arxiv"
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
def __init__(self, config):
super().__init__(config)
...
@@ -442,6 +442,67 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.init_weights()
def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings
"""
return self.embeddings.position_embeddings
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
the size will remove vectors from the end.
"""
num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings
# no resizing needs to be done if the length stays the same
if num_position_embeds_diff == 0:
return
logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
self.config.max_position_embeddings = new_num_position_embeddings
old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone()
self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)
if self.config.sinusoidal_pos_embds:
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.embeddings.position_embeddings.weight, modifier_rank=0):
if torch.distributed.get_rank() == 0:
create_sinusoidal_embeddings(
n_pos=self.config.max_position_embeddings,
dim=self.config.dim,
out=self.embeddings.position_embeddings.weight,
)
else:
create_sinusoidal_embeddings(
n_pos=self.config.max_position_embeddings,
dim=self.config.dim,
out=self.embeddings.position_embeddings.weight,
)
else:
with torch.no_grad():
if num_position_embeds_diff > 0:
self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter(
old_position_embeddings_weight
)
else:
self.embeddings.position_embeddings.weight = nn.Parameter(
old_position_embeddings_weight[:num_position_embeds_diff]
)
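`create_sinusoidal_embeddings`, called in the branch above but not shown in this diff, fills the table with the fixed sin/cos pattern from the original Transformer paper, which is why growing a sinusoidal table simply appends "correct" vectors at the end. Roughly, as a sketch of the standard formula (not the library's exact code; it assumes an even hidden size):

import math
import torch

def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    # table[pos, 2i] = sin(pos / 10000**(2i/dim)), table[pos, 2i+1] = cos(pos / 10000**(2i/dim))
    table = torch.zeros(n_pos, dim)
    position = torch.arange(n_pos, dtype=torch.float).unsqueeze(1)
    inv_freq = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim))
    table[:, 0::2] = torch.sin(position * inv_freq)
    table[:, 1::2] = torch.cos(position * inv_freq)
    return table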
def get_input_embeddings(self):
return self.embeddings.word_embeddings
@@ -525,6 +586,27 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
self.mlm_loss_fct = nn.CrossEntropyLoss()
def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings
"""
return self.distilbert.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
the size will remove vectors from the end.
"""
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
def get_output_embeddings(self):
return self.vocab_projector
@@ -608,6 +690,27 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
self.init_weights()
def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings
"""
return self.distilbert.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
the size will remove vectors from the end.
"""
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
...@@ -703,6 +806,27 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel): ...@@ -703,6 +806,27 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
self.init_weights() self.init_weights()
def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings
"""
return self.distilbert.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
the size will remove vectors from the end.
"""
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices")) @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
...@@ -799,6 +923,27 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel): ...@@ -799,6 +923,27 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
self.init_weights() self.init_weights()
def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings
"""
return self.distilbert.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
the size will remove vectors from the end.
"""
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -883,6 +1028,27 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
self.init_weights()
def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings
"""
return self.distilbert.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size will add
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
will remove vectors from the end.
"""
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
@add_start_docstrings_to_model_forward(
DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
...
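With the DistilBERT overrides above, the new API can be exercised directly; a usage sketch (shapes assume the standard `distilbert-base-uncased` checkpoint, which has 512 learned positions and hidden size 768):

from transformers import DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased")
print(model.get_position_embeddings().weight.shape)  # expected: torch.Size([512, 768])

# Grow the learned table: the first 512 vectors are preserved, the new ones are freshly initialized.
model.resize_position_embeddings(1024)
print(model.config.max_position_embeddings)          # 1024
print(model.get_position_embeddings().weight.shape)  # expected: torch.Size([1024, 768])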
@@ -480,17 +480,6 @@ class PegasusPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
dummy_inputs = {
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
"decoder_input_ids": input_ids,
}
return dummy_inputs
PEGASUS_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
@@ -658,6 +647,34 @@ class PegasusEncoder(PegasusPreTrainedModel):
self.init_weights()
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings matrix of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size will add
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
will remove vectors from the end.
"""
logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
self.config.max_position_embeddings = new_num_position_embeddings
self.embed_positions = PegasusSinusoidalPositionalEmbedding(
self.config.max_position_embeddings,
self.config.d_model,
self.padding_idx,
)
def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings matrix
"""
return self.embed_positions
def forward(
self,
input_ids=None,
@@ -848,6 +865,34 @@ class PegasusDecoder(PegasusPreTrainedModel):
return combined_attention_mask
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings matrix of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size will add
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
will remove vectors from the end.
"""
logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
self.config.max_position_embeddings = new_num_position_embeddings
self.embed_positions = PegasusSinusoidalPositionalEmbedding(
self.config.max_position_embeddings,
self.config.d_model,
self.padding_idx,
)
def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings matrix
"""
return self.embed_positions
def forward(
self,
input_ids=None,
@@ -1097,6 +1142,29 @@ class PegasusModel(PegasusPreTrainedModel):
def get_decoder(self):
return self.decoder
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings matrix of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size will add
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
will remove vectors from the end.
"""
self.config.max_position_embeddings = new_num_position_embeddings
self.encoder.resize_position_embeddings(new_num_position_embeddings)
self.decoder.resize_position_embeddings(new_num_position_embeddings)
def get_position_embeddings(self) -> Tuple[nn.Embedding]:
"""
Returns the position embeddings matrix
"""
return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1237,6 +1305,29 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings matrix of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size will add
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
will remove vectors from the end.
"""
self.config.max_position_embeddings = new_num_position_embeddings
self.model.encoder.resize_position_embeddings(new_num_position_embeddings)
self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
def get_position_embeddings(self) -> Tuple[nn.Embedding]:
"""
Returns the position embeddings matrix
"""
return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
@@ -1373,7 +1464,6 @@ class PegasusDecoderWrapper(PegasusPreTrainedModel):
return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Pegasus
class PegasusForCausalLM(PegasusPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1404,7 +1494,30 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
def get_decoder(self):
return self.model.decoder
def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings matrix
"""
return self.model.decoder.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings matrix of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (:obj:`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size will add
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
will remove vectors from the end.
"""
self.config.max_position_embeddings = new_num_position_embeddings
self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus
def forward(
self,
input_ids=None,
...
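Since Pegasus uses sinusoidal positions, a resize just rebuilds the encoder and decoder tables at the new length, and `get_position_embeddings` returns them as a tuple. A usage sketch (`google/pegasus-xsum` is used purely for illustration; the default length differs per checkpoint):

from transformers import PegasusForConditionalGeneration

model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
enc_pos, dec_pos = model.get_position_embeddings()      # (encoder table, decoder table)
print(enc_pos.weight.shape[0], model.config.max_position_embeddings)

model.resize_position_embeddings(2048)                  # rebuilds both sinusoidal tables
enc_pos, dec_pos = model.get_position_embeddings()
assert enc_pos.weight.shape[0] == 2048
assert dec_pos.weight.shape[0] == 2048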
@@ -94,6 +94,7 @@ class ModelTesterMixin:
test_torchscript = True
test_pruning = True
test_resize_embeddings = True
test_resize_position_embeddings = False
test_head_masking = True
test_missing_keys = True
test_model_parallel = False
@@ -1067,6 +1068,85 @@ class ModelTesterMixin:
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def test_resize_position_vector_embeddings(self):
if not self.test_resize_position_embeddings:
return
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
if self.model_tester.is_training is False:
model.eval()
max_position_embeddings = config.max_position_embeddings
# Retrieve the embeddings and clone them
if model.config.is_encoder_decoder:
encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
encoder_cloned_embeddings = encoder_model_embed.weight.clone()
decoder_cloned_embeddings = decoder_model_embed.weight.clone()
else:
model_embed = model.get_position_embeddings()
cloned_embeddings = model_embed.weight.clone()
# Check that resizing the position embeddings with a larger max_position_embeddings increases
# the model's position embeddings size
model.resize_position_embeddings(max_position_embeddings + 10)
self.assertEqual(model.config.max_position_embeddings, max_position_embeddings + 10)
# Check that it actually resizes the embeddings matrix
if model.config.is_encoder_decoder:
encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
self.assertEqual(encoder_model_embed.weight.shape[0], encoder_cloned_embeddings.shape[0] + 10)
self.assertEqual(decoder_model_embed.weight.shape[0], decoder_cloned_embeddings.shape[0] + 10)
else:
model_embed = model.get_position_embeddings()
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the position embeddings with a smaller max_position_embeddings decreases
# the model's max_position_embeddings
model.resize_position_embeddings(max_position_embeddings - 5)
self.assertEqual(model.config.max_position_embeddings, max_position_embeddings - 5)
# Check that it actually resizes the embeddings matrix
if model.config.is_encoder_decoder:
encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
self.assertEqual(encoder_model_embed.weight.shape[0], encoder_cloned_embeddings.shape[0] - 5)
self.assertEqual(decoder_model_embed.weight.shape[0], decoder_cloned_embeddings.shape[0] - 5)
else:
model_embed = model.get_position_embeddings()
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 5)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing up and back down has not modified the retained part of the position embedding matrix.
models_equal = True
if model.config.is_encoder_decoder:
for p1, p2 in zip(encoder_cloned_embeddings, encoder_model_embed.weight):
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
for p1, p2 in zip(decoder_cloned_embeddings, decoder_model_embed.weight):
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
else:
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
def test_resize_tokens_embeddings(self):
(
original_config,
...
@@ -214,6 +214,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
test_torchscript = True
test_resize_embeddings = True
test_sequence_classification_problem_types = True
test_resize_position_embeddings = True
def setUp(self):
self.model_tester = DistilBertModelTester(self)
...
@@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
test_resize_position_embeddings = True
test_pruning = False
test_missing_keys = False
@@ -526,6 +527,7 @@ class PegasusStandaloneDecoderModelTester:
class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (PegasusDecoder, PegasusForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (PegasusForCausalLM,) if is_torch_available() else ()
test_resize_position_embeddings = True
test_pruning = False
is_encoder_decoder = False
...