Unverified commit 5bf4465e, authored Aug 20, 2020 by Sam Shleifer, committed by GitHub on Aug 20, 2020

Regression test for pegasus bugfix (#6606)

Parent: 86c07e63
Showing 4 changed files with 74 additions and 64 deletions:

src/transformers/configuration_pegasus.py          +42 -4
src/transformers/convert_pegasus_tf_to_pytorch.py   +7 -43
src/transformers/modeling_pegasus.py                 +7 -0
tests/test_modeling_pegasus.py                      +18 -17
src/transformers/configuration_pegasus.py
@@ -22,6 +22,7 @@ from .file_utils import add_start_docstrings_to_callable
 logger = logging.getLogger(__name__)
 
+# These config values do not vary between checkpoints
 DEFAULTS = dict(
     vocab_size=96103,
     max_position_embeddings=512,
@@ -46,6 +47,47 @@ DEFAULTS = dict(
     num_beams=8,
     activation_function="relu",
 )
+# Config values that vary between checkpoints: for testing and conversion
+max_gen_length = {
+    # See appendix C of paper
+    "xsum": 64,
+    "cnn_dailymail": 128,
+    "newsroom": 128,
+    "wikihow": 256,
+    "multi_news": 256,
+    "reddit_tifu": 128,
+    "big_patent": 256,
+    "arxiv": 256,
+    "pubmed": 256,
+    "gigaword": 32,
+    "aeslc": 32,
+    "billsum": 256,
+    "large": 256,  # @sshleifer chose arbitrarily
+}
+max_model_length = {
+    "xsum": 512,
+    "cnn_dailymail": 1024,
+    "newsroom": 512,
+    "wikihow": 512,
+    "multi_news": 1024,
+    "reddit_tifu": 512,
+    "big_patent": 1024,
+    "arxiv": 1024,
+    "pubmed": 1024,
+    "gigaword": 128,
+    "aeslc": 512,
+    "billsum": 1024,
+    "large": 1024,
+}
+expected_alpha = {
+    "multinews": 0.9,
+    "wikihow": 0.6,
+    "reddit_tifu": 0.6,
+    "big_patent": 0.7,
+    "gigaword": 0.6,
+    "aeslc": 0.6,
+    "billsum": 0.6,
+}  # otherwise 0.8
 
 @add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
@@ -56,7 +98,3 @@ class PegasusConfig(BartConfig):
     """
 
     model_type = "pegasus"
     # The implementation of the config object is in BartConfig
-
-    @property
-    def default_config_parameters(self):
-        return DEFAULTS
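For orientation (not part of the diff): the three tables above become module-level constants that other code can import. A minimal lookup sketch, assuming a transformers version that contains this commit; the dataset key "xsum" is only an example:

from transformers.configuration_pegasus import (
    DEFAULTS,
    expected_alpha,
    max_gen_length,
    max_model_length,
)

dataset = "xsum"  # example key; any dataset listed in the tables works
print(max_gen_length[dataset])              # 64  -> per-dataset max summary length (appendix C)
print(max_model_length[dataset])            # 512 -> longest input the checkpoint supports
print(expected_alpha.get(dataset, 0.8))     # 0.8 -> length penalty; 0.8 unless listed in expected_alpha
print(DEFAULTS["max_position_embeddings"])  # 512 -> shared default, overridden per checkpoint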
src/transformers/convert_pegasus_tf_to_pytorch.py
@@ -22,7 +22,7 @@ import torch
 from tqdm import tqdm
 
 from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
-from transformers.configuration_pegasus import DEFAULTS
+from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length
 
 
 PATTERNS = [
@@ -52,47 +52,7 @@ def rename_state_dict_key(k):
 
 # See appendix C of paper for all hyperparams
-max_gen_length = {
-    # See appendix C of paper
-    "xsum": 64,
-    "cnn_dailymail": 128,
-    "newsroom": 128,
-    "wikihow": 256,
-    "multi_news": 256,
-    "reddit_tifu": 128,
-    "big_patent": 256,
-    "arxiv": 256,
-    "pubmed": 256,
-    "gigaword": 32,
-    "aeslc": 32,
-    "billsum": 256,
-    "large": 256,  # @sshleifer chose arbitrarily
-}
-max_model_length = {
-    "xsum": 512,
-    "cnn_dailymail": 1024,
-    "newsroom": 512,
-    "wikihow": 512,
-    "multi_news": 1024,
-    "reddit_tifu": 512,
-    "big_patent": 1024,
-    "arxiv": 1024,
-    "pubmed": 1024,
-    "gigaword": 128,
-    "aeslc": 512,
-    "billsum": 1024,
-    "large": 1024,
-}
-expected_alpha = {
-    "multinews": 0.9,
-    "wikihow": 0.6,
-    "reddit_tifu": 0.6,
-    "big_patent": 0.7,
-    "gigaword": 0.6,
-    "aeslc": 0.6,
-    "billsum": 0.6,
-}  # otherwise 0.8
 # TODO(SS): one constant
@@ -151,7 +111,11 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
     # convert model
     tf_weights = get_tf_weights_as_numpy(ckpt_path)
-    cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
+    cfg_updates = dict(
+        max_length=max_gen_length[dataset],
+        length_penalty=expected_alpha.get(dataset, 0.8),
+        max_position_embeddings=desired_max_model_length,
+    )
     torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
     torch_model.save_pretrained(save_dir)
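The substance of the fix is the last hunk: the converter now writes max_position_embeddings into cfg_updates instead of letting the converted config fall back to the 512-token value in DEFAULTS. desired_max_model_length is computed elsewhere in convert_pegasus_ckpt_to_pytorch (not shown in this hunk); the sketch below assumes it mirrors max_model_length, uses "arxiv" only as an example of a long-input checkpoint, and merges dicts as a rough stand-in for the config step inside convert_pegasus_to_bart:

from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length

dataset = "arxiv"  # example: a checkpoint whose inputs are longer than 512 tokens
desired_max_model_length = max_model_length[dataset]  # assumption: how the converter derives it

# Before the fix only max_length / length_penalty were overridden, so the
# merged config kept DEFAULTS' 512 position embeddings.
old_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
assert dict(DEFAULTS, **old_updates)["max_position_embeddings"] == 512  # too small for 1024-token inputs

# After the fix the checkpoint's true input length reaches the config.
new_updates = dict(
    max_length=max_gen_length[dataset],
    length_penalty=expected_alpha.get(dataset, 0.8),
    max_position_embeddings=desired_max_model_length,
)
assert dict(DEFAULTS, **new_updates)["max_position_embeddings"] == 1024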
src/transformers/modeling_pegasus.py
@@ -23,6 +23,13 @@ from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
 
 @add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
 class PegasusForConditionalGeneration(BartForConditionalGeneration):
     config_class = PegasusConfig
+    authorized_missing_keys = [
+        r"final_logits_bias",
+        r"encoder\.version",
+        r"decoder\.version",
+        r"model.encoder.embed_positions",
+        "model.decoder.embed_positions",
+    ]
     r"""
     Pytorch version of google's pegasus model for summarization.
     Model API is identical to BartForConditionalGeneration.
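authorized_missing_keys lists state-dict keys that are allowed to be absent from a saved checkpoint, so from_pretrained does not report them as missing weights; for Pegasus the position embeddings are sinusoidal and generated rather than stored. A minimal sketch of the effect, assuming a transformers version containing this commit and access to the public Hub checkpoint:

from transformers import PegasusForConditionalGeneration

# Keys matching the patterns above (final_logits_bias, *.version,
# *.embed_positions) may be absent from the checkpoint; with
# authorized_missing_keys they are no longer flagged when loading.
model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
print(type(model).__name__)  # PegasusForConditionalGeneration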
tests/test_modeling_pegasus.py
 import unittest
 
-from transformers import AutoConfig, is_torch_available
+from transformers import AutoConfig, AutoTokenizer, is_torch_available
+from transformers.configuration_pegasus import max_gen_length, max_model_length
 from transformers.file_utils import cached_property
 from transformers.testing_utils import require_torch, slow, torch_device
@@ -50,28 +51,28 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
 
 class PegasusConfigTests(unittest.TestCase):
     def test_all_config_max_lengths(self):
-        expected_max_length = {
-            # See appendix C of paper
-            "xsum": 64,
-            "cnn_dailymail": 128,
-            "newsroom": 128,
-            "wikihow": 256,
-            "multi_news": 256,
-            "reddit_tifu": 128,
-            "big_patent": 256,
-            "arxiv": 256,
-            "pubmed": 256,
-            "gigaword": 32,
-            "aeslc": 32,
-            "billsum": 256,
-        }
         failures = []
         pegasus_prefix = "google/pegasus"
-        for dataset, max_len in expected_max_length.items():
+        for dataset, max_len in max_gen_length.items():
             mname = f"{pegasus_prefix}-{dataset}"
             cfg = AutoConfig.from_pretrained(mname)
             if cfg.max_length != max_len:
                 failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}")
+            if cfg.max_position_embeddings < max_model_length[dataset]:
+                # otherwise you get IndexError for e.g. position 513
+                # see https://github.com/huggingface/transformers/issues/6599
+                failures.append(
+                    f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}"
+                )
+            tokenizer = AutoTokenizer.from_pretrained(mname)
+            if max_model_length[dataset] != tokenizer.model_max_length:
+                failures.append(f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}")
 
         if failures == []:
             return
         # error
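The new checks encode the regression described in https://github.com/huggingface/transformers/issues/6599: if a config advertises fewer position embeddings than the tokenizer's model_max_length, the tokenizer can produce inputs that index past the position-embedding table. A rough illustration of the invariant being tested, assuming network access and a transformers release that includes this commit; the checkpoint name is just one of the public Pegasus models:

from transformers import AutoConfig, AutoTokenizer

mname = "google/pegasus-cnn_dailymail"  # example checkpoint with 1024-token inputs
cfg = AutoConfig.from_pretrained(mname)
tokenizer = AutoTokenizer.from_pretrained(mname)

# The tokenizer truncates to model_max_length; the config must provide at least
# that many position embeddings, otherwise position 513 (for a 512-slot table)
# raises an IndexError inside the encoder.
assert cfg.max_position_embeddings >= tokenizer.model_max_length

batch = tokenizer(["a very long article " * 400], truncation=True, return_tensors="pt")
assert batch["input_ids"].shape[1] <= cfg.max_position_embeddings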