chenpangpang / transformers · Commits · 9cc9f412

Unverified commit 9cc9f412, authored Dec 11, 2020 by Patrick von Platen, committed by GitHub on Dec 11, 2020
Make ProphetNetModel really compatible with EncoderDecoder (#9033)
* improve
* finish
* upload model
* fix lm head
* fix test
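In practice, the fix means the official checkpoints now load directly into the causal-LM class, rather than requiring the interim patrickvonplaten/*-decoder-clm uploads. A minimal sketch of the user-facing effect, taken from the updated docstrings in the diff below (running it downloads the pretrained weights from the model hub):

    from transformers import ProphetNetForCausalLM, ProphetNetTokenizer

    tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
    model = ProphetNetForCausalLM.from_pretrained("microsoft/prophetnet-large-uncased")
    assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."

    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    outputs = model(**inputs)  # ProphetNetDecoderLMOutput; next-token scores in outputs.logits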
parent 24f6cdea
Showing 5 changed files with 59 additions and 15 deletions
src/transformers/models/prophetnet/modeling_prophetnet.py (+24 −10)
src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py (+2 −2)
tests/test_modeling_encoder_decoder.py (+1 −3)
tests/test_modeling_prophetnet.py (+30 −0)
utils/check_repo.py (+2 −0)
src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -1886,7 +1886,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
         config = copy.deepcopy(config)
         config.is_decoder = True
         config.is_encoder_decoder = False
-        self.decoder = ProphetNetDecoder(config)
+        self.prophetnet = ProphetNetDecoderWrapper(config)
         self.padding_idx = config.pad_token_id
         self.disable_ngram_loss = config.disable_ngram_loss
@@ -1896,10 +1896,10 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
         self.init_weights()
 
     def get_input_embeddings(self):
-        return self.decoder.word_embeddings
+        return self.prophetnet.decoder.word_embeddings
 
     def set_input_embeddings(self, value):
-        self.decoder.word_embeddings = value
+        self.prophetnet.decoder.word_embeddings = value
 
     def get_output_embeddings(self):
         return self.lm_head
@@ -1907,6 +1907,12 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
+    def set_decoder(self, decoder):
+        self.prophetnet.decoder = decoder
+
+    def get_decoder(self):
+        return self.prophetnet.decoder
+
     @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1956,7 +1962,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
             >>> import torch
 
             >>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
-            >>> model = ProphetNetForCausalLM.from_pretrained('patrickvonplaten/prophetnet-decoder-clm-large-uncased')
+            >>> model = ProphetNetForCausalLM.from_pretrained('microsoft/prophetnet-large-uncased')
             >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
             >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
             >>> outputs = model(**inputs)
@@ -1969,7 +1975,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
             >>> tokenizer_enc = BertTokenizer.from_pretrained('bert-large-uncased')
             >>> tokenizer_dec = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
-            >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased")
+            >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "microsoft/prophetnet-large-uncased")
 
             >>> ARTICLE = (
             ...     "the us state department said wednesday it had received no "
@@ -1985,7 +1991,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.decoder(
+        outputs = self.prophetnet.decoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
@@ -2086,8 +2092,16 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
             reordered_past.append(layer_past_new)
         return reordered_past
 
-    def set_decoder(self, decoder):
-        self.decoder = decoder
 
-    def get_decoder(self):
-        return self.decoder
+class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
+    """
+    This is a wrapper class, so that :class:`~transformers.ProphetNetForCausalLM` can correctly be loaded from
+    pretrained prophetnet classes.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.decoder = ProphetNetDecoder(config)
+
+    def forward(self, *args, **kwargs):
+        return self.decoder(*args, **kwargs)
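Note on the new wrapper above: it exists purely so checkpoint weight names line up. from_pretrained matches saved tensors by attribute path, and a full ProphetNetModel saves its decoder weights under the prefix prophetnet.decoder.*. A minimal plain-PyTorch analogy (stand-in classes, not the real ProphetNet modules) showing how the extra indirection reproduces that prefix:

    import torch.nn as nn

    class Decoder(nn.Module):  # stand-in for ProphetNetDecoder
        def __init__(self):
            super().__init__()
            self.word_embeddings = nn.Embedding(10, 4)

    class DecoderWrapper(nn.Module):  # stand-in for ProphetNetDecoderWrapper
        def __init__(self):
            super().__init__()
            self.decoder = Decoder()

    class CausalLM(nn.Module):  # stand-in for ProphetNetForCausalLM after this commit
        def __init__(self):
            super().__init__()
            self.prophetnet = DecoderWrapper()

    # Prints ['prophetnet.decoder.word_embeddings.weight'] -- the same prefix a
    # full ProphetNetModel checkpoint uses, so its decoder weights load as-is.
    print(list(CausalLM().state_dict().keys()))

Before this commit, ProphetNetForCausalLM exposed the decoder directly as self.decoder, so keys saved by ProphetNetModel could not be matched without remapping.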
src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
@@ -136,7 +136,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
         >>> import torch
 
         >>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
-        >>> model = XLMProphetNetForCausalLM.from_pretrained('patrickvonplaten/xprophetnet-decoder-clm-large-uncased')
+        >>> model = XLMProphetNetForCausalLM.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
         >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
         >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
         >>> outputs = model(**inputs)
@@ -149,7 +149,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
         >>> tokenizer_enc = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
         >>> tokenizer_dec = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
-        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", "patrickvonplaten/xprophetnet-decoder-clm-large-uncased")
+        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", 'microsoft/xprophetnet-large-wiki100-cased')
 
         >>> ARTICLE = (
         ...     "the us state department said wednesday it had received no "
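With both docstrings now pointing at the official checkpoints, warm-starting an encoder-decoder pair from two pretrained models is a one-liner. A sketch mirroring the updated XLM docstring above (running it downloads the pretrained weights from the model hub):

    from transformers import EncoderDecoderModel

    # XLM-RoBERTa encoder combined with an XLM-ProphetNet decoder, as in the
    # updated docstring; the decoder is loaded through XLMProphetNetForCausalLM,
    # which this commit makes loadable from the official checkpoint.
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "xlm-roberta-large", "microsoft/xprophetnet-large-wiki100-cased"
    )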
tests/test_modeling_encoder_decoder.py
@@ -802,9 +802,7 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
         }
 
     def get_pretrained_model(self):
-        return EncoderDecoderModel.from_encoder_decoder_pretrained(
-            "bert-large-uncased", "prophetnet-large-uncased"
-        )
+        return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased")
 
     def test_encoder_decoder_model_shared_weights(self):
         pass
tests/test_modeling_prophetnet.py
@@ -38,6 +38,7 @@ if is_torch_available():
         ProphetNetModel,
         ProphetNetTokenizer,
     )
+    from transformers.modeling_outputs import BaseModelOutput
 
 
 class ProphetNetModelTester:
@@ -467,6 +468,31 @@ class ProphetNetModelTester:
             )
         )
 
+    def check_causal_lm_from_pretrained(
+        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, *args
+    ):
+        model = ProphetNetForConditionalGeneration(config).to(torch_device).eval()
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model.save_pretrained(tmp_dirname)
+            decoder = ProphetNetForCausalLM.from_pretrained(tmp_dirname).to(torch_device)
+
+        encoder_hidden_states = model.prophetnet.encoder(input_ids).last_hidden_state
+
+        model_outputs = model(
+            encoder_outputs=BaseModelOutput(last_hidden_state=encoder_hidden_states),
+            decoder_input_ids=decoder_input_ids,
+        )
+        dec_outputs = decoder(encoder_hidden_states=encoder_hidden_states, input_ids=decoder_input_ids)
+
+        self.parent.assertTrue(
+            torch.allclose(
+                model_outputs.logits[0, :5],
+                dec_outputs.logits[0, :5],
+                atol=1e-3,
+            )
+        )
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -898,6 +924,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
         self.assertFalse(config.add_cross_attention)
 
+    def test_causal_lm_from_pretrained(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.check_causal_lm_from_pretrained(*config_and_inputs)
+
     @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
     def test_fp16_forward(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
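The new test pins the contract down end to end: weights saved from ProphetNetForConditionalGeneration must reload into ProphetNetForCausalLM and produce matching logits. Below is a condensed, self-contained version of that round trip on a deliberately tiny random config so it runs offline; the config field names come from ProphetNetConfig, but the small sizes are illustrative assumptions, not values from the diff:

    import tempfile

    import torch
    from transformers import (
        ProphetNetConfig,
        ProphetNetForCausalLM,
        ProphetNetForConditionalGeneration,
    )
    from transformers.modeling_outputs import BaseModelOutput

    # Tiny randomly initialized config (illustrative sizes).
    config = ProphetNetConfig(
        vocab_size=32,
        hidden_size=16,
        num_encoder_layers=1,
        num_decoder_layers=1,
        num_encoder_attention_heads=2,
        num_decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=64,
    )
    model = ProphetNetForConditionalGeneration(config).eval()

    # Round trip: save the full encoder-decoder model, reload only the decoder
    # side (the unused encoder weights are skipped with a warning).
    with tempfile.TemporaryDirectory() as tmp_dirname:
        model.save_pretrained(tmp_dirname)
        decoder = ProphetNetForCausalLM.from_pretrained(tmp_dirname).eval()

    input_ids = torch.tensor([[4, 5, 6, 7]])
    decoder_input_ids = torch.tensor([[4, 5, 6]])

    encoder_hidden_states = model.prophetnet.encoder(input_ids).last_hidden_state
    full = model(
        encoder_outputs=BaseModelOutput(last_hidden_state=encoder_hidden_states),
        decoder_input_ids=decoder_input_ids,
    ).logits
    standalone = decoder(
        input_ids=decoder_input_ids, encoder_hidden_states=encoder_hidden_states
    ).logits
    assert torch.allclose(full[0, :3], standalone[0, :3], atol=1e-3)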
utils/check_repo.py
@@ -34,6 +34,7 @@ IGNORE_NON_TESTED = [
     "BertLMHeadModel",  # Needs to be setup as decoder.
     "DPREncoder",  # Building part of bigger (tested) model.
     "DPRSpanPredictor",  # Building part of bigger (tested) model.
+    "ProphetNetDecoderWrapper",  # Building part of bigger (tested) model.
     "ReformerForMaskedLM",  # Needs to be setup as decoder.
     "T5Stack",  # Building part of bigger (tested) model.
     "TFDPREncoder",  # Building part of bigger (tested) model.
@@ -74,6 +75,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
     "OpenAIGPTDoubleHeadsModel",
     "ProphetNetDecoder",
     "ProphetNetEncoder",
+    "ProphetNetDecoderWrapper",
     "RagModel",
     "RagSequenceForGeneration",
     "RagTokenForGeneration",