"docs/en/usage/output_files.md" did not exist on "cf704253f01908f5702650b71da7f65dc5d044e0"
test_modeling_pegasus.py 6.02 KB
Newer Older
1
2
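"""Tests for PyTorch Pegasus: fast common tests on a tiny config, slow integration
tests against google/pegasus-xsum, and consistency checks between
task_specific_params and the released checkpoint configs."""
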
import unittest

from transformers import AutoConfig, AutoTokenizer, is_torch_available
from transformers.file_utils import cached_property
from transformers.models.pegasus.configuration_pegasus import task_specific_params
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.utils.logging import ERROR, set_verbosity

from .test_modeling_bart import PGE_ARTICLE
from .test_modeling_common import ModelTesterMixin
from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest


if is_torch_available():
    from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration


# Longer XSUM-style source article, paired with PGE_ARTICLE in the integration test below.
XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a  re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """

# Suppress transformers log messages below ERROR level for these tests.
set_verbosity(ERROR)


@require_torch
class ModelTester:
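    """Build a tiny PegasusConfig so the shared common tests run quickly without pretrained weights."""
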
    def __init__(self, parent):
        self.config = PegasusConfig(
            vocab_size=99,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            add_final_layer_norm=True,
        )

    def prepare_config_and_inputs_for_common(self):
        return self.config, {}


@require_torch
class SelectiveCommonTest(unittest.TestCase):
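    """Run a hand-picked subset of the common ModelTesterMixin checks against Pegasus."""
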
    all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()

    # Borrow just this one check from the common mixin instead of inheriting the full suite.
    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save

    def setUp(self):
        self.model_tester = ModelTester(self)


@require_torch
@require_sentencepiece
@require_tokenizers
class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
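    """Slow integration tests against the pretrained google/pegasus-xsum checkpoint."""
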
    checkpoint_name = "google/pegasus-xsum"
    src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER]
    tgt_text = [
        "California's largest electricity provider has turned off power to hundreds of thousands of customers.",
        "Pop group N-Dubz have revealed they were surprised to get four nominations for this year's Mobo Awards.",
    ]

    @cached_property
    def model(self):
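        """Load the pretrained checkpoint once and cache it for the rest of the test run."""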
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name).to(torch_device)

    @slow
    def test_pegasus_xsum_summary(self):
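        """Summarize two articles and compare generations to reference summaries (fp32, plus fp16 on CUDA)."""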
        assert self.tokenizer.model_max_length == 512
        inputs = self.tokenizer(self.src_text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(
            torch_device
        )
        assert inputs.input_ids.shape == (2, 421)
        translated_tokens = self.model.generate(**inputs, num_beams=2)
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        assert self.tgt_text == decoded

        if "cuda" not in torch_device:
            return
        # Demonstrate the fp16 generation issue; contributions welcome!
        self.model.half()
        translated_tokens_fp16 = self.model.generate(**inputs, max_length=10)
        decoded_fp16 = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
        assert decoded_fp16 == [
            "California's largest electricity provider has begun",
            "N-Dubz have revealed they were",
        ]


class PegasusConfigTests(unittest.TestCase):
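    """Consistency checks between task_specific_params and the released Pegasus configs."""
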
    @slow
    def test_task_specific_params(self):
        """Test that task_specific params['summarization_xsum'] == config['pegasus_xsum'] """
        failures = []
        pegasus_prefix = "google/pegasus"
        n_prefix_chars = len("summarization_")
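        # Task keys look like "summarization_xsum"; strip the prefix to recover the dataset name.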
        for task, desired_settings in task_specific_params.items():
            dataset = task[n_prefix_chars:]
            mname = f"{pegasus_prefix}-{dataset}"
            cfg = AutoConfig.from_pretrained(mname)
            for k, v in desired_settings.items():
                actual_value = getattr(cfg, k)
                if actual_value != v:
                    failures.append(f"config for {mname} had {k}: {actual_value}, expected {v}")
            tokenizer = AutoTokenizer.from_pretrained(mname)
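            # The tokenizer's model_max_length should agree with the config's max_position_embeddings.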
            n_pos_embeds = desired_settings["max_position_embeddings"]
            if n_pos_embeds != tokenizer.model_max_length:
                failures.append(f"tokenizer.model_max_length {tokenizer.model_max_length} expected {n_pos_embeds}")

        # Join all collected failures into one readable assertion message rather than failing on the first.
        all_fails = "\n".join(failures)
        assert not failures, f"The following configs have unexpected settings: {all_fails}"