Commit 9cc9f412 (unverified)
Authored Dec 11, 2020 by Patrick von Platen, committed by GitHub on Dec 11, 2020
Make ProphetNetModel really compatible with EncoderDecoder (#9033)
* improve
* finish
* upload model
* fix lm head
* fix test
Parent: 24f6cdea
Showing 5 changed files with 59 additions and 15 deletions:

  src/transformers/models/prophetnet/modeling_prophetnet.py (+24, -10)
  src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py (+2, -2)
  tests/test_modeling_encoder_decoder.py (+1, -3)
  tests/test_modeling_prophetnet.py (+30, -0)
  utils/check_repo.py (+2, -0)
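What the commit changes: ProphetNetForCausalLM now stores its decoder under a `prophetnet` submodule (via the new ProphetNetDecoderWrapper), so its state-dict keys line up with the `prophetnet.*` prefix used by full ProphetNet checkpoints, and the stock `microsoft/prophetnet-large-uncased` weights can be dropped straight into an EncoderDecoderModel. A minimal usage sketch of what this enables, mirroring the updated docstring examples in the diffs below (running it downloads the pretrained weights):

    from transformers import BertTokenizer, EncoderDecoderModel, ProphetNetTokenizer

    # After this commit the official checkpoint works directly as the decoder;
    # before, a separately re-uploaded decoder-only checkpoint was required.
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-large-uncased", "microsoft/prophetnet-large-uncased"
    )
    tokenizer_enc = BertTokenizer.from_pretrained("bert-large-uncased")
    tokenizer_dec = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")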
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -1886,7 +1886,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
         config = copy.deepcopy(config)
         config.is_decoder = True
         config.is_encoder_decoder = False
-        self.decoder = ProphetNetDecoder(config)
+        self.prophetnet = ProphetNetDecoderWrapper(config)
         self.padding_idx = config.pad_token_id
         self.disable_ngram_loss = config.disable_ngram_loss
@@ -1896,10 +1896,10 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
         self.init_weights()

     def get_input_embeddings(self):
-        return self.decoder.word_embeddings
+        return self.prophetnet.decoder.word_embeddings

     def set_input_embeddings(self, value):
-        self.decoder.word_embeddings = value
+        self.prophetnet.decoder.word_embeddings = value

     def get_output_embeddings(self):
         return self.lm_head
@@ -1907,6 +1907,12 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

+    def set_decoder(self, decoder):
+        self.prophetnet.decoder = decoder
+
+    def get_decoder(self):
+        return self.prophetnet.decoder
+
     @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1956,7 +1962,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
             >>> import torch

             >>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
-            >>> model = ProphetNetForCausalLM.from_pretrained('patrickvonplaten/prophetnet-decoder-clm-large-uncased')
+            >>> model = ProphetNetForCausalLM.from_pretrained('microsoft/prophetnet-large-uncased')
             >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
             >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
             >>> outputs = model(**inputs)
@@ -1969,7 +1975,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
             >>> tokenizer_enc = BertTokenizer.from_pretrained('bert-large-uncased')
             >>> tokenizer_dec = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
-            >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased")
+            >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "microsoft/prophetnet-large-uncased")

             >>> ARTICLE = (
             ...     "the us state department said wednesday it had received no "
@@ -1985,7 +1991,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.decoder(
+        outputs = self.prophetnet.decoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
@@ -2086,8 +2092,16 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
                 reordered_past.append(layer_past_new)
         return reordered_past

-    def set_decoder(self, decoder):
-        self.decoder = decoder
-
-    def get_decoder(self):
-        return self.decoder
+
+class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
+    """
+    This is a wrapper class, so that :class:`~transformers.ProphetNetForCausalLM` can correctly be loaded from
+    pretrained prophetnet classes.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.decoder = ProphetNetDecoder(config)
+
+    def forward(self, *args, **kwargs):
+        return self.decoder(*args, **kwargs)
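The `set_decoder`/`get_decoder` pair that used to sit at the bottom of ProphetNetForCausalLM (operating on `self.decoder`) moves up and now targets `self.prophetnet.decoder`, and the new ProphetNetDecoderWrapper holds the decoder under a module attribute named `prophetnet`. The point of the wrapper is purely state-dict naming: `from_pretrained` matches checkpoint weights to module parameters by key, and full ProphetNet checkpoints store decoder weights under the `prophetnet.decoder.` prefix. A toy sketch of the naming trick, independent of transformers (the class names here are stand-ins, not the real API):

    import torch.nn as nn

    class Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

    class Wrapper(nn.Module):  # plays the role of ProphetNetDecoderWrapper
        def __init__(self):
            super().__init__()
            self.decoder = Decoder()

        def forward(self, *args, **kwargs):
            # Pure delegation; the wrapper adds no parameters of its own.
            return self.decoder(*args, **kwargs)

    class CausalLM(nn.Module):  # plays the role of ProphetNetForCausalLM
        def __init__(self):
            super().__init__()
            self.prophetnet = Wrapper()

    lm = CausalLM()
    # Keys now carry the same "prophetnet.decoder." prefix a full ProphetNet
    # checkpoint uses, so weights can be matched by name:
    print(list(lm.state_dict()))
    # ['prophetnet.decoder.proj.weight', 'prophetnet.decoder.proj.bias']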
--- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
+++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
@@ -136,7 +136,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
         >>> import torch

         >>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
-        >>> model = XLMProphetNetForCausalLM.from_pretrained('patrickvonplaten/xprophetnet-decoder-clm-large-uncased')
+        >>> model = XLMProphetNetForCausalLM.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
         >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
         >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
         >>> outputs = model(**inputs)
@@ -149,7 +149,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
         >>> tokenizer_enc = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
         >>> tokenizer_dec = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
-        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", "patrickvonplaten/xprophetnet-decoder-clm-large-uncased")
+        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", 'microsoft/xprophetnet-large-wiki100-cased')

         >>> ARTICLE = (
         ...     "the us state department said wednesday it had received no "
--- a/tests/test_modeling_encoder_decoder.py
+++ b/tests/test_modeling_encoder_decoder.py
@@ -802,9 +802,7 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
         }

     def get_pretrained_model(self):
-        return EncoderDecoderModel.from_encoder_decoder_pretrained(
-            "bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased"
-        )
+        return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "prophetnet-large-uncased")

     def test_encoder_decoder_model_shared_weights(self):
         pass
--- a/tests/test_modeling_prophetnet.py
+++ b/tests/test_modeling_prophetnet.py
@@ -38,6 +38,7 @@ if is_torch_available():
         ProphetNetModel,
         ProphetNetTokenizer,
     )
+    from transformers.modeling_outputs import BaseModelOutput


 class ProphetNetModelTester:
@@ -467,6 +468,31 @@ class ProphetNetModelTester:
             )
         )

+    def check_causal_lm_from_pretrained(
+        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, *args
+    ):
+        model = ProphetNetForConditionalGeneration(config).to(torch_device).eval()
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model.save_pretrained(tmp_dirname)
+            decoder = ProphetNetForCausalLM.from_pretrained(tmp_dirname).to(torch_device)
+
+        encoder_hidden_states = model.prophetnet.encoder(input_ids).last_hidden_state
+        model_outputs = model(
+            encoder_outputs=BaseModelOutput(last_hidden_state=encoder_hidden_states),
+            decoder_input_ids=decoder_input_ids,
+        )
+
+        dec_outputs = decoder(encoder_hidden_states=encoder_hidden_states, input_ids=decoder_input_ids)
+
+        self.parent.assertTrue(
+            torch.allclose(
+                model_outputs.logits[0, :5],
+                dec_outputs.logits[0, :5],
+                atol=1e-3,
+            )
+        )
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -898,6 +924,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         self.assertFalse(config.add_cross_attention)

+    def test_causal_lm_from_pretrained(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.check_causal_lm_from_pretrained(*config_and_inputs)
+
     @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
     def test_fp16_forward(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
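The new `check_causal_lm_from_pretrained` test saves a full ProphetNetForConditionalGeneration model, reloads it as a ProphetNetForCausalLM, and asserts that the standalone decoder's logits match the seq2seq decoder's logits. The same round trip can be exercised outside the test harness; a minimal sketch, assuming the checkpoint is cached locally or can be downloaded:

    import tempfile

    from transformers import ProphetNetForCausalLM, ProphetNetForConditionalGeneration

    # Save a full encoder-decoder checkpoint, then load it decoder-only.
    seq2seq = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")
    with tempfile.TemporaryDirectory() as tmp_dirname:
        seq2seq.save_pretrained(tmp_dirname)
        # Decoder weights map onto `prophetnet.decoder`; encoder weights are skipped.
        decoder_only = ProphetNetForCausalLM.from_pretrained(tmp_dirname)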
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -34,6 +34,7 @@ IGNORE_NON_TESTED = [
     "BertLMHeadModel",  # Needs to be setup as decoder.
     "DPREncoder",  # Building part of bigger (tested) model.
     "DPRSpanPredictor",  # Building part of bigger (tested) model.
+    "ProphetNetDecoderWrapper",  # Building part of bigger (tested) model.
     "ReformerForMaskedLM",  # Needs to be setup as decoder.
     "T5Stack",  # Building part of bigger (tested) model.
     "TFDPREncoder",  # Building part of bigger (tested) model.
@@ -74,6 +75,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
     "OpenAIGPTDoubleHeadsModel",
     "ProphetNetDecoder",
     "ProphetNetEncoder",
+    "ProphetNetDecoderWrapper",
     "RagModel",
     "RagSequenceForGeneration",
     "RagTokenForGeneration",
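`utils/check_repo.py` is the repository-consistency script: entries in `IGNORE_NON_TESTED` are exempt from the check that every model has a dedicated test, and entries in `IGNORE_NON_AUTO_CONFIGURED` are exempt from the check that every model is reachable through an auto class. ProphetNetDecoderWrapper goes on both lists because it is an internal building block, not a user-facing model. A hypothetical sketch of how such an ignore list is typically consulted (the function and variable names below are illustrative, not the script's actual API):

    IGNORE_NON_TESTED = [
        "ProphetNetDecoderWrapper",  # Building part of bigger (tested) model.
    ]

    def check_models_are_tested(all_model_names, tested_model_names):
        # Flag any model that has no dedicated test and is not exempted.
        failures = []
        for name in all_model_names:
            if name in IGNORE_NON_TESTED:
                continue  # internal component, covered by its parent model's tests
            if name not in tested_model_names:
                failures.append(f"{name} has no dedicated test.")
        return failures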