chenpangpang / transformers · Commits · 5bf4465e

Unverified commit 5bf4465e, authored Aug 20, 2020 by Sam Shleifer, committed by GitHub on Aug 20, 2020.

Regression test for pegasus bugfix (#6606)
parent 86c07e63
Showing 4 changed files with 74 additions and 64 deletions (+74 / -64)
src/transformers/configuration_pegasus.py          +42  -4
src/transformers/convert_pegasus_tf_to_pytorch.py   +7  -43
src/transformers/modeling_pegasus.py                +7  -0
tests/test_modeling_pegasus.py                     +18  -17
src/transformers/configuration_pegasus.py  (view file @ 5bf4465e)

@@ -22,6 +22,7 @@ from .file_utils import add_start_docstrings_to_callable
 logger = logging.getLogger(__name__)

+# These config values do not vary between checkpoints
 DEFAULTS = dict(
     vocab_size=96103,
     max_position_embeddings=512,
@@ -46,6 +47,47 @@ DEFAULTS = dict(
     num_beams=8,
     activation_function="relu",
 )
+# Config values that vary between checkpoints: for testing and conversion
+max_gen_length = {
+    # See appendix C of paper
+    "xsum": 64,
+    "cnn_dailymail": 128,
+    "newsroom": 128,
+    "wikihow": 256,
+    "multi_news": 256,
+    "reddit_tifu": 128,
+    "big_patent": 256,
+    "arxiv": 256,
+    "pubmed": 256,
+    "gigaword": 32,
+    "aeslc": 32,
+    "billsum": 256,
+    "large": 256,  # @sshleifer chose arbitrarily
+}
+max_model_length = {
+    "xsum": 512,
+    "cnn_dailymail": 1024,
+    "newsroom": 512,
+    "wikihow": 512,
+    "multi_news": 1024,
+    "reddit_tifu": 512,
+    "big_patent": 1024,
+    "arxiv": 1024,
+    "pubmed": 1024,
+    "gigaword": 128,
+    "aeslc": 512,
+    "billsum": 1024,
+    "large": 1024,
+}
+expected_alpha = {
+    "multinews": 0.9,
+    "wikihow": 0.6,
+    "reddit_tifu": 0.6,
+    "big_patent": 0.7,
+    "gigaword": 0.6,
+    "aeslc": 0.6,
+    "billsum": 0.6,
+}  # otherwise 0.8


 @add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
@@ -56,7 +98,3 @@ class PegasusConfig(BartConfig):
     """

     model_type = "pegasus"
     # The implementation of the config object is in BartConfig
-
-    @property
-    def default_config_parameters(self):
-        return DEFAULTS
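The tables above only matter once they are merged with DEFAULTS into a per-checkpoint configuration, which is what the conversion script in the next file does. As a minimal sketch (not part of this commit, and assuming a transformers version contemporary with it where configuration_pegasus exposes these module-level names; per_dataset_config_kwargs is a hypothetical helper, not something the module defines):

# Sketch: build per-dataset config kwargs the same way the converter does.
from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length


def per_dataset_config_kwargs(dataset: str) -> dict:
    """Merge the shared DEFAULTS with the values that vary between checkpoints."""
    return {
        **DEFAULTS,
        "max_length": max_gen_length[dataset],                  # generation budget (appendix C)
        "max_position_embeddings": max_model_length[dataset],   # must cover the longest expected input
        "length_penalty": expected_alpha.get(dataset, 0.8),     # alpha; 0.8 unless listed above
    }


print(per_dataset_config_kwargs("xsum")["max_length"])  # 64

The 0.8 fallback mirrors the "# otherwise 0.8" comment next to expected_alpha, and the max_position_embeddings override is exactly the value that stayed stuck at the 512 from DEFAULTS before this fix.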
src/transformers/convert_pegasus_tf_to_pytorch.py  (view file @ 5bf4465e)

@@ -22,7 +22,7 @@ import torch
 from tqdm import tqdm

 from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
-from transformers.configuration_pegasus import DEFAULTS
+from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length


 PATTERNS = [
@@ -52,47 +52,7 @@ def rename_state_dict_key(k):

 # See appendix C of paper for all hyperparams
-max_gen_length = {
-    # See appendix C of paper
-    "xsum": 64,
-    "cnn_dailymail": 128,
-    "newsroom": 128,
-    "wikihow": 256,
-    "multi_news": 256,
-    "reddit_tifu": 128,
-    "big_patent": 256,
-    "arxiv": 256,
-    "pubmed": 256,
-    "gigaword": 32,
-    "aeslc": 32,
-    "billsum": 256,
-    "large": 256,  # @sshleifer chose arbitrarily
-}
-max_model_length = {
-    "xsum": 512,
-    "cnn_dailymail": 1024,
-    "newsroom": 512,
-    "wikihow": 512,
-    "multi_news": 1024,
-    "reddit_tifu": 512,
-    "big_patent": 1024,
-    "arxiv": 1024,
-    "pubmed": 1024,
-    "gigaword": 128,
-    "aeslc": 512,
-    "billsum": 1024,
-    "large": 1024,
-}
-expected_alpha = {
-    "multinews": 0.9,
-    "wikihow": 0.6,
-    "reddit_tifu": 0.6,
-    "big_patent": 0.7,
-    "gigaword": 0.6,
-    "aeslc": 0.6,
-    "billsum": 0.6,
-}  # otherwise 0.8
 # TODO(SS): one constant
@@ -151,7 +111,11 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
     # convert model
     tf_weights = get_tf_weights_as_numpy(ckpt_path)
-    cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
+    cfg_updates = dict(
+        max_length=max_gen_length[dataset],
+        length_penalty=expected_alpha.get(dataset, 0.8),
+        max_position_embeddings=desired_max_model_length,
+    )
     torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
     torch_model.save_pretrained(save_dir)
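A hedged usage sketch (not from the commit) of the entry point named in the last hunk header. The checkpoint path below is hypothetical; the function derives the dataset name and desired_max_model_length from code outside the displayed hunk, and TensorFlow needs to be installed so the TF weights can be read.

# Convert one downloaded Pegasus TF checkpoint into a directory usable via from_pretrained().
from transformers.convert_pegasus_tf_to_pytorch import convert_pegasus_ckpt_to_pytorch

ckpt_path = "pegasus_ckpts/xsum/model.ckpt-0"  # hypothetical local path to a TF checkpoint
save_dir = "pegasus-xsum-pytorch"              # output directory passed to save_pretrained()

convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir)

With the change above, the saved config carries max_position_embeddings=desired_max_model_length instead of the 512 inherited from DEFAULTS.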
src/transformers/modeling_pegasus.py  (view file @ 5bf4465e)

@@ -23,6 +23,13 @@ from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
 @add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
 class PegasusForConditionalGeneration(BartForConditionalGeneration):
     config_class = PegasusConfig
+    authorized_missing_keys = [
+        r"final_logits_bias",
+        r"encoder\.version",
+        r"decoder\.version",
+        r"model.encoder.embed_positions",
+        "model.decoder.embed_positions",
+    ]
     r"""
     Pytorch version of google's pegasus model for summarization.
     Model API is identical to BartForConditionalGeneration.
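For illustration only (not part of the commit): authorized_missing_keys is a list of regex patterns, and state-dict keys reported as missing during from_pretrained are dropped from that report when they match one of the patterns. That is what lets converted checkpoints omit the position-embedding buffers, which this implementation can rebuild from static sinusoids. A rough standalone sketch of the matching rule:

import re

authorized_missing_keys = [
    r"final_logits_bias",
    r"encoder\.version",
    r"decoder\.version",
    r"model.encoder.embed_positions",
    "model.decoder.embed_positions",
]

# Hypothetical missing-key list for a converted checkpoint.
missing = [
    "model.encoder.embed_positions.weight",
    "model.decoder.embed_positions.weight",
    "final_logits_bias",
    "model.shared.weight",  # not covered by any pattern, so it would still be reported
]

still_reported = [k for k in missing if not any(re.search(pat, k) for pat in authorized_missing_keys)]
print(still_reported)  # ['model.shared.weight']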
tests/test_modeling_pegasus.py  (view file @ 5bf4465e)

 import unittest

-from transformers import AutoConfig, is_torch_available
+from transformers import AutoConfig, AutoTokenizer, is_torch_available
+from transformers.configuration_pegasus import max_gen_length, max_model_length
 from transformers.file_utils import cached_property
 from transformers.testing_utils import require_torch, slow, torch_device
@@ -50,28 +51,28 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
 class PegasusConfigTests(unittest.TestCase):
     def test_all_config_max_lengths(self):
-        expected_max_length = {
-            # See appendix C of paper
-            "xsum": 64,
-            "cnn_dailymail": 128,
-            "newsroom": 128,
-            "wikihow": 256,
-            "multi_news": 256,
-            "reddit_tifu": 128,
-            "big_patent": 256,
-            "arxiv": 256,
-            "pubmed": 256,
-            "gigaword": 32,
-            "aeslc": 32,
-            "billsum": 256,
-        }
         failures = []
         pegasus_prefix = "google/pegasus"
-        for dataset, max_len in expected_max_length.items():
+        for dataset, max_len in max_gen_length.items():
             mname = f"{pegasus_prefix}-{dataset}"
             cfg = AutoConfig.from_pretrained(mname)
             if cfg.max_length != max_len:
                 failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}")
+            if cfg.max_position_embeddings < max_model_length[dataset]:
+                # otherwise you get IndexError for e.g. position 513
+                # see https://github.com/huggingface/transformers/issues/6599
+                failures.append(
+                    f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}"
+                )
+            tokenizer = AutoTokenizer.from_pretrained(mname)
+            if max_model_length[dataset] != tokenizer.model_max_length:
+                failures.append(
+                    f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}"
+                )
         if failures == []:
             return
         # error
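The new max_position_embeddings check guards the failure mode referenced in issue #6599: a position table sized for 512 positions cannot embed a 513-token input. A tiny self-contained repro of that IndexError, using a plain nn.Embedding rather than the actual Pegasus module:

import torch
import torch.nn as nn

# 512 available positions, as in the pre-fix Pegasus configs.
embed_positions = nn.Embedding(512, 16)

positions = torch.arange(513)  # position ids 0..512; 512 is one past the table
try:
    embed_positions(positions)
except IndexError as err:
    print("IndexError, as in #6599:", err)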