"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "81262c7b7296269cd40f80d6f735812b1c941c08"
Unverified commit 9cc9f412, authored by Patrick von Platen, committed by GitHub
Browse files

Make ProphetNetModel really compatible with EncoderDecoder (#9033)

* improve

* finish

* upload model

* fix lm head

* fix test
parent 24f6cdea
...@@ -1886,7 +1886,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -1886,7 +1886,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
config = copy.deepcopy(config) config = copy.deepcopy(config)
config.is_decoder = True config.is_decoder = True
config.is_encoder_decoder = False config.is_encoder_decoder = False
self.decoder = ProphetNetDecoder(config) self.prophetnet = ProphetNetDecoderWrapper(config)
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.disable_ngram_loss = config.disable_ngram_loss self.disable_ngram_loss = config.disable_ngram_loss
...@@ -1896,10 +1896,10 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -1896,10 +1896,10 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
self.init_weights() self.init_weights()
def get_input_embeddings(self): def get_input_embeddings(self):
return self.decoder.word_embeddings return self.prophetnet.decoder.word_embeddings
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.decoder.word_embeddings = value self.prophetnet.decoder.word_embeddings = value
def get_output_embeddings(self): def get_output_embeddings(self):
return self.lm_head return self.lm_head
...@@ -1907,6 +1907,12 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -1907,6 +1907,12 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
def set_output_embeddings(self, new_embeddings): def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.prophetnet.decoder = decoder
def get_decoder(self):
return self.prophetnet.decoder
@add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
...@@ -1956,7 +1962,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -1956,7 +1962,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
>>> import torch >>> import torch
>>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased') >>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
>>> model = ProphetNetForCausalLM.from_pretrained('patrickvonplaten/prophetnet-decoder-clm-large-uncased') >>> model = ProphetNetForCausalLM.from_pretrained('microsoft/prophetnet-large-uncased')
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
...@@ -1969,7 +1975,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -1969,7 +1975,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
>>> tokenizer_enc = BertTokenizer.from_pretrained('bert-large-uncased') >>> tokenizer_enc = BertTokenizer.from_pretrained('bert-large-uncased')
>>> tokenizer_dec = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased') >>> tokenizer_dec = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased") >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "microsoft/prophetnet-large-uncased")
>>> ARTICLE = ( >>> ARTICLE = (
... "the us state department said wednesday it had received no " ... "the us state department said wednesday it had received no "
...@@ -1985,7 +1991,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -1985,7 +1991,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.decoder( outputs = self.prophetnet.decoder(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
...@@ -2086,8 +2092,16 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -2086,8 +2092,16 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
reordered_past.append(layer_past_new) reordered_past.append(layer_past_new)
return reordered_past return reordered_past
def set_decoder(self, decoder):
self.decoder = decoder
def get_decoder(self): class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
return self.decoder """
This is a wrapper class, so that :class:`~transformers.ProphetNetForCausalLM` can correctly be loaded from
pretrained prophetnet classes.
"""
def __init__(self, config):
super().__init__(config)
self.decoder = ProphetNetDecoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
...@@ -136,7 +136,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM): ...@@ -136,7 +136,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
>>> import torch >>> import torch
>>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased') >>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
>>> model = XLMProphetNetForCausalLM.from_pretrained('patrickvonplaten/xprophetnet-decoder-clm-large-uncased') >>> model = XLMProphetNetForCausalLM.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
...@@ -149,7 +149,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM): ...@@ -149,7 +149,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
>>> tokenizer_enc = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') >>> tokenizer_enc = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
>>> tokenizer_dec = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased') >>> tokenizer_dec = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", "patrickvonplaten/xprophetnet-decoder-clm-large-uncased") >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", 'microsoft/xprophetnet-large-wiki100-cased')
>>> ARTICLE = ( >>> ARTICLE = (
... "the us state department said wednesday it had received no " ... "the us state department said wednesday it had received no "
......
...@@ -802,9 +802,7 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): ...@@ -802,9 +802,7 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
} }
def get_pretrained_model(self): def get_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained( return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "prophetnet-large-uncased")
"bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased"
)
def test_encoder_decoder_model_shared_weights(self): def test_encoder_decoder_model_shared_weights(self):
pass pass
...@@ -38,6 +38,7 @@ if is_torch_available(): ...@@ -38,6 +38,7 @@ if is_torch_available():
ProphetNetModel, ProphetNetModel,
ProphetNetTokenizer, ProphetNetTokenizer,
) )
from transformers.modeling_outputs import BaseModelOutput
class ProphetNetModelTester: class ProphetNetModelTester:
...@@ -467,6 +468,31 @@ class ProphetNetModelTester: ...@@ -467,6 +468,31 @@ class ProphetNetModelTester:
) )
) )
def check_causal_lm_from_pretrained(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, *args
):
model = ProphetNetForConditionalGeneration(config).to(torch_device).eval()
with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname)
decoder = ProphetNetForCausalLM.from_pretrained(tmp_dirname).to(torch_device)
encoder_hidden_states = model.prophetnet.encoder(input_ids).last_hidden_state
model_outputs = model(
encoder_outputs=BaseModelOutput(last_hidden_state=encoder_hidden_states),
decoder_input_ids=decoder_input_ids,
)
dec_outputs = decoder(encoder_hidden_states=encoder_hidden_states, input_ids=decoder_input_ids)
self.parent.assertTrue(
torch.allclose(
model_outputs.logits[0, :5],
dec_outputs.logits[0, :5],
atol=1e-3,
)
)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
( (
...@@ -898,6 +924,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test ...@@ -898,6 +924,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
self.assertFalse(config.add_cross_attention) self.assertFalse(config.add_cross_attention)
def test_causal_lm_from_pretrained(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_causal_lm_from_pretrained(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision") @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_fp16_forward(self): def test_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
......
...@@ -34,6 +34,7 @@ IGNORE_NON_TESTED = [ ...@@ -34,6 +34,7 @@ IGNORE_NON_TESTED = [
"BertLMHeadModel", # Needs to be setup as decoder. "BertLMHeadModel", # Needs to be setup as decoder.
"DPREncoder", # Building part of bigger (tested) model. "DPREncoder", # Building part of bigger (tested) model.
"DPRSpanPredictor", # Building part of bigger (tested) model. "DPRSpanPredictor", # Building part of bigger (tested) model.
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
"ReformerForMaskedLM", # Needs to be setup as decoder. "ReformerForMaskedLM", # Needs to be setup as decoder.
"T5Stack", # Building part of bigger (tested) model. "T5Stack", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (tested) model. "TFDPREncoder", # Building part of bigger (tested) model.
...@@ -74,6 +75,7 @@ IGNORE_NON_AUTO_CONFIGURED = [ ...@@ -74,6 +75,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
"OpenAIGPTDoubleHeadsModel", "OpenAIGPTDoubleHeadsModel",
"ProphetNetDecoder", "ProphetNetDecoder",
"ProphetNetEncoder", "ProphetNetEncoder",
"ProphetNetDecoderWrapper",
"RagModel", "RagModel",
"RagSequenceForGeneration", "RagSequenceForGeneration",
"RagTokenForGeneration", "RagTokenForGeneration",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment