Unverified commit 5bf4465e authored by Sam Shleifer, committed by GitHub

Regression test for pegasus bugfix (#6606)

parent 86c07e63
@@ -22,6 +22,7 @@ from .file_utils import add_start_docstrings_to_callable
 logger = logging.getLogger(__name__)
+# These config values do not vary between checkpoints
 DEFAULTS = dict(
     vocab_size=96103,
     max_position_embeddings=512,
@@ -46,6 +47,47 @@ DEFAULTS = dict(
     num_beams=8,
     activation_function="relu",
 )
+# Config values that vary between checkpoints: for testing and conversion
+max_gen_length = {
+    # See appendix C of paper
+    "xsum": 64,
+    "cnn_dailymail": 128,
+    "newsroom": 128,
+    "wikihow": 256,
+    "multi_news": 256,
+    "reddit_tifu": 128,
+    "big_patent": 256,
+    "arxiv": 256,
+    "pubmed": 256,
+    "gigaword": 32,
+    "aeslc": 32,
+    "billsum": 256,
+    "large": 256,  # @sshleifer chose arbitrarily
+}
+max_model_length = {
+    "xsum": 512,
+    "cnn_dailymail": 1024,
+    "newsroom": 512,
+    "wikihow": 512,
+    "multi_news": 1024,
+    "reddit_tifu": 512,
+    "big_patent": 1024,
+    "arxiv": 1024,
+    "pubmed": 1024,
+    "gigaword": 128,
+    "aeslc": 512,
+    "billsum": 1024,
+    "large": 1024,
+}
+expected_alpha = {
+    "multinews": 0.9,
+    "wikihow": 0.6,
+    "reddit_tifu": 0.6,
+    "big_patent": 0.7,
+    "gigaword": 0.6,
+    "aeslc": 0.6,
+    "billsum": 0.6,
+}  # otherwise 0.8
 @add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
@@ -56,7 +98,3 @@ class PegasusConfig(BartConfig):
     """
     model_type = "pegasus"
     # The implementation of the config object is in BartConfig
-    @property
-    def default_config_parameters(self):
-        return DEFAULTS
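For orientation, the three tables added above feed the checkpoint conversion script and the regression test further down. A minimal sketch (not part of this commit; the helper name is hypothetical) of how the shared DEFAULTS and the per-checkpoint tables combine into the overrides one checkpoint needs:

```python
# Illustrative sketch only: combine DEFAULTS with the per-checkpoint tables above.
def checkpoint_config_kwargs(dataset: str) -> dict:
    kwargs = dict(DEFAULTS)                                  # values shared by all checkpoints
    kwargs.update(
        max_length=max_gen_length[dataset],                  # generation budget (appendix C)
        max_position_embeddings=max_model_length[dataset],   # longest input the model can embed
        length_penalty=expected_alpha.get(dataset, 0.8),     # alpha; 0.8 unless listed above
    )
    return kwargs


# e.g. checkpoint_config_kwargs("xsum") yields max_length=64, max_position_embeddings=512
```

The conversion script below builds essentially this set of overrides when it constructs cfg_updates.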
@@ -22,7 +22,7 @@ import torch
 from tqdm import tqdm
 from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
-from transformers.configuration_pegasus import DEFAULTS
+from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length
 PATTERNS = [
@@ -52,47 +52,7 @@ def rename_state_dict_key(k):
 # See appendix C of paper for all hyperparams
-max_gen_length = {
-    # See appendix C of paper
-    "xsum": 64,
-    "cnn_dailymail": 128,
-    "newsroom": 128,
-    "wikihow": 256,
-    "multi_news": 256,
-    "reddit_tifu": 128,
-    "big_patent": 256,
-    "arxiv": 256,
-    "pubmed": 256,
-    "gigaword": 32,
-    "aeslc": 32,
-    "billsum": 256,
-    "large": 256,  # @sshleifer chose arbitrarily
-}
-max_model_length = {
-    "xsum": 512,
-    "cnn_dailymail": 1024,
-    "newsroom": 512,
-    "wikihow": 512,
-    "multi_news": 1024,
-    "reddit_tifu": 512,
-    "big_patent": 1024,
-    "arxiv": 1024,
-    "pubmed": 1024,
-    "gigaword": 128,
-    "aeslc": 512,
-    "billsum": 1024,
-    "large": 1024,
-}
-expected_alpha = {
-    "multinews": 0.9,
-    "wikihow": 0.6,
-    "reddit_tifu": 0.6,
-    "big_patent": 0.7,
-    "gigaword": 0.6,
-    "aeslc": 0.6,
-    "billsum": 0.6,
-}  # otherwise 0.8
 # TODO(SS): one constant
@@ -151,7 +111,11 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
     # convert model
     tf_weights = get_tf_weights_as_numpy(ckpt_path)
-    cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
+    cfg_updates = dict(
+        max_length=max_gen_length[dataset],
+        length_penalty=expected_alpha.get(dataset, 0.8),
+        max_position_embeddings=desired_max_model_length,
+    )
     torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
     torch_model.save_pretrained(save_dir)
......
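The max_position_embeddings override added to cfg_updates is the substance of the fix: issue #6599 reported an IndexError when an input runs past the position-embedding table (e.g. position 513 with a 512-entry table). The desired_max_model_length variable is defined outside the visible hunk, presumably from max_model_length[dataset]. An illustration of the failure mode using a plain nn.Embedding as a stand-in for the position embeddings (sketch only, not Pegasus code):

```python
# Sketch of the failure mode behind issue #6599: a 512-entry embedding table
# cannot embed token positions >= 512, so longer inputs raise an IndexError.
import torch
import torch.nn as nn

embed_positions = nn.Embedding(512, 64)     # 512 positions, toy hidden size
_ = embed_positions(torch.arange(512))      # positions 0..511: fine
try:
    embed_positions(torch.arange(513))      # includes position 512: out of range
except IndexError as err:
    print("IndexError:", err)
```

Sizing max_position_embeddings from the per-dataset max_model_length table, and keeping the tokenizer's model_max_length in sync with it, keeps every input position inside the table.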
@@ -23,6 +23,13 @@ from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
 @add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
 class PegasusForConditionalGeneration(BartForConditionalGeneration):
     config_class = PegasusConfig
+    authorized_missing_keys = [
+        r"final_logits_bias",
+        r"encoder\.version",
+        r"decoder\.version",
+        r"model.encoder.embed_positions",
+        "model.decoder.embed_positions",
+    ]
     r"""
     Pytorch version of google's pegasus model for summarization.
     Model API is identical to BartForConditionalGeneration.
......
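For readers unfamiliar with authorized_missing_keys: these regex patterns tell from_pretrained which state-dict keys may legitimately be absent. Pegasus builds its static sinusoidal position embeddings at init time rather than relying on the checkpoint, so their absence should not be flagged. A minimal sketch of the filtering as modeling_utils applied it around this time (reproduced from memory, so treat the details as approximate):

```python
# Sketch: missing keys matching any authorized pattern are dropped from the
# "missing keys" warning, leaving only genuinely unexpected gaps.
import re

authorized_missing_keys = [
    r"final_logits_bias",
    r"encoder\.version",
    r"decoder\.version",
    r"model.encoder.embed_positions",
    "model.decoder.embed_positions",
]

missing_keys = ["model.encoder.embed_positions.weight", "model.shared.weight"]
for pattern in authorized_missing_keys:
    missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]

print(missing_keys)  # ['model.shared.weight'] -- only real problems remain
```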
 import unittest
-from transformers import AutoConfig, is_torch_available
+from transformers import AutoConfig, AutoTokenizer, is_torch_available
+from transformers.configuration_pegasus import max_gen_length, max_model_length
 from transformers.file_utils import cached_property
 from transformers.testing_utils import require_torch, slow, torch_device
@@ -50,28 +51,28 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
 class PegasusConfigTests(unittest.TestCase):
     def test_all_config_max_lengths(self):
-        expected_max_length = {
-            # See appendix C of paper
-            "xsum": 64,
-            "cnn_dailymail": 128,
-            "newsroom": 128,
-            "wikihow": 256,
-            "multi_news": 256,
-            "reddit_tifu": 128,
-            "big_patent": 256,
-            "arxiv": 256,
-            "pubmed": 256,
-            "gigaword": 32,
-            "aeslc": 32,
-            "billsum": 256,
-        }
         failures = []
         pegasus_prefix = "google/pegasus"
-        for dataset, max_len in expected_max_length.items():
+        for dataset, max_len in max_gen_length.items():
             mname = f"{pegasus_prefix}-{dataset}"
             cfg = AutoConfig.from_pretrained(mname)
             if cfg.max_length != max_len:
failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}") failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}")
+            if cfg.max_position_embeddings < max_model_length[dataset]:
+                # otherwise you get IndexError for e.g. position 513
+                # see https://github.com/huggingface/transformers/issues/6599
+                failures.append(
+                    f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}"
+                )
+            tokenizer = AutoTokenizer.from_pretrained(mname)
+            if max_model_length[dataset] != tokenizer.model_max_length:
+                failures.append(
+                    f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}"
+                )
         if failures == []:
             return
         # error
......
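Taken together, the test ties three numbers to each other for every released checkpoint: the generation budget (max_length), the position-table size (max_position_embeddings), and the tokenizer's truncation limit (model_max_length). A sketch of that invariant for a single checkpoint (assumes network access and that the hosted google/pegasus-xsum config carries the fixed values the test checks for):

```python
# Sketch of the invariant the regression test enforces, for one checkpoint.
# Values are the xsum row of the tables in configuration_pegasus.py above.
from transformers import AutoConfig, AutoTokenizer

cfg = AutoConfig.from_pretrained("google/pegasus-xsum")
tok = AutoTokenizer.from_pretrained("google/pegasus-xsum")

assert cfg.max_length == 64                # generation budget (appendix C)
assert cfg.max_position_embeddings >= 512  # room to embed every input position
assert tok.model_max_length == 512         # tokenizer truncates to what the model can embed
```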