Unverified Commit 9cc9f412 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Make ProphetNetModel really compatible with EncoderDecoder (#9033)

* improve

* finish

* upload model

* fix lm head

* fix test
parent 24f6cdea
......@@ -1886,7 +1886,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
self.decoder = ProphetNetDecoder(config)
self.prophetnet = ProphetNetDecoderWrapper(config)
self.padding_idx = config.pad_token_id
self.disable_ngram_loss = config.disable_ngram_loss
......@@ -1896,10 +1896,10 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
self.init_weights()
def get_input_embeddings(self):
return self.decoder.word_embeddings
return self.prophetnet.decoder.word_embeddings
def set_input_embeddings(self, value):
self.decoder.word_embeddings = value
self.prophetnet.decoder.word_embeddings = value
def get_output_embeddings(self):
return self.lm_head
......@@ -1907,6 +1907,12 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
    def set_decoder(self, decoder):
        """Replace the wrapped decoder module (stored on ``self.prophetnet``)."""
        self.prophetnet.decoder = decoder

    def get_decoder(self):
        """Return the wrapped decoder module (stored on ``self.prophetnet``)."""
        return self.prophetnet.decoder
@add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
......@@ -1956,7 +1962,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
>>> import torch
>>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
>>> model = ProphetNetForCausalLM.from_pretrained('patrickvonplaten/prophetnet-decoder-clm-large-uncased')
>>> model = ProphetNetForCausalLM.from_pretrained('microsoft/prophetnet-large-uncased')
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
......@@ -1969,7 +1975,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
>>> tokenizer_enc = BertTokenizer.from_pretrained('bert-large-uncased')
>>> tokenizer_dec = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased")
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "microsoft/prophetnet-large-uncased")
>>> ARTICLE = (
... "the us state department said wednesday it had received no "
......@@ -1985,7 +1991,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.decoder(
outputs = self.prophetnet.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
......@@ -2086,8 +2092,16 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
reordered_past.append(layer_past_new)
return reordered_past
    def set_decoder(self, decoder):
        """Replace the decoder module used by this causal LM."""
        self.decoder = decoder

    def get_decoder(self):
        """Return the decoder module used by this causal LM."""
        return self.decoder
class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
    """
    This is a wrapper class, so that :class:`~transformers.ProphetNetForCausalLM` can correctly be loaded from
    pretrained prophetnet classes.

    The wrapper holds the stand-alone decoder under the ``decoder`` attribute so
    that checkpoint weight names (prefixed ``prophetnet.decoder...``) line up when
    loading a causal LM from a full encoder-decoder checkpoint.
    """

    def __init__(self, config):
        super().__init__(config)
        # Attribute name ``decoder`` is load-bearing: it determines the
        # state-dict prefix used when mapping pretrained weights.
        self.decoder = ProphetNetDecoder(config)

    def forward(self, *args, **kwargs):
        # Pure delegation — the wrapper adds no computation of its own.
        return self.decoder(*args, **kwargs)
......@@ -136,7 +136,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
>>> import torch
>>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
>>> model = XLMProphetNetForCausalLM.from_pretrained('patrickvonplaten/xprophetnet-decoder-clm-large-uncased')
>>> model = XLMProphetNetForCausalLM.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
......@@ -149,7 +149,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
>>> tokenizer_enc = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
>>> tokenizer_dec = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", "patrickvonplaten/xprophetnet-decoder-clm-large-uncased")
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", 'microsoft/xprophetnet-large-wiki100-cased')
>>> ARTICLE = (
... "the us state department said wednesday it had received no "
......
......@@ -802,9 +802,7 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
}
def get_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained(
"bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased"
)
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "prophetnet-large-uncased")
def test_encoder_decoder_model_shared_weights(self):
pass
......@@ -38,6 +38,7 @@ if is_torch_available():
ProphetNetModel,
ProphetNetTokenizer,
)
from transformers.modeling_outputs import BaseModelOutput
class ProphetNetModelTester:
......@@ -467,6 +468,31 @@ class ProphetNetModelTester:
)
)
def check_causal_lm_from_pretrained(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, *args
):
model = ProphetNetForConditionalGeneration(config).to(torch_device).eval()
with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname)
decoder = ProphetNetForCausalLM.from_pretrained(tmp_dirname).to(torch_device)
encoder_hidden_states = model.prophetnet.encoder(input_ids).last_hidden_state
model_outputs = model(
encoder_outputs=BaseModelOutput(last_hidden_state=encoder_hidden_states),
decoder_input_ids=decoder_input_ids,
)
dec_outputs = decoder(encoder_hidden_states=encoder_hidden_states, input_ids=decoder_input_ids)
self.parent.assertTrue(
torch.allclose(
model_outputs.logits[0, :5],
dec_outputs.logits[0, :5],
atol=1e-3,
)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
......@@ -898,6 +924,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
self.assertFalse(config.add_cross_attention)
def test_causal_lm_from_pretrained(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_causal_lm_from_pretrained(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
......
......@@ -34,6 +34,7 @@ IGNORE_NON_TESTED = [
"BertLMHeadModel", # Needs to be setup as decoder.
"DPREncoder", # Building part of bigger (tested) model.
"DPRSpanPredictor", # Building part of bigger (tested) model.
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
"ReformerForMaskedLM", # Needs to be setup as decoder.
"T5Stack", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (tested) model.
......@@ -74,6 +75,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
"OpenAIGPTDoubleHeadsModel",
"ProphetNetDecoder",
"ProphetNetEncoder",
"ProphetNetDecoderWrapper",
"RagModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment