# test_t5_examples.py
import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_cnn import run_generate


# Module-level output/score file names.
# NOTE(review): nothing visible in this file reads these (test_t5_cli builds its
# own paths); confirm no other module imports them before removing.
output_file_name = "output_t5_sum.txt"
score_file_name = "score_t5_sum.txt"

# Single-article corpus; test_t5_cli writes it to a file and passes that file to
# run_generate.
articles = [
    "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."
]

# Configure root logging at DEBUG as a side effect of importing this module.
logging.basicConfig(level=logging.DEBUG)

# Root logger (getLogger() with no name); test_t5_cli attaches a stdout handler
# to it for the duration of the test.
logger = logging.getLogger()


class TestT5Examples(unittest.TestCase):
    """Smoke test driving ``run_generate`` with the tiny random T5 checkpoint."""

    def test_t5_cli(self):
        """Invoke ``run_generate`` as if from the command line and check it writes its files.

        ``sys.argv`` is patched so ``run_generate`` sees the same argument list a
        shell invocation of ``evaluate_cnn.py`` would. The article file is passed
        in both the 3rd and 5th argv slots. NOTE(review): presumably those are the
        input and reference paths; confirm against evaluate_cnn's argument parser.
        """
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Detach the handler afterwards. Otherwise each run leaves another handler
        # on the root logger, and every later log line is printed multiple times.
        self.addCleanup(logger.removeHandler, stream_handler)

        # Use a private, auto-deleted directory instead of fixed names in the shared
        # system temp dir. Fixed names collide between concurrent runs. Worse, stale
        # files from a previous run would let the exists() checks below pass even if
        # run_generate wrote nothing.
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            tmp_dir = Path(tmp_dir_name)

            articles_path = tmp_dir / "utest_generations_t5_sum.hypo"
            articles_path.write_text("\n".join(articles), encoding="utf-8")

            # Distinct names so the module-level output_file_name/score_file_name
            # constants are not shadowed.
            output_path = tmp_dir / "utest_output_t5_sum.hypo"
            score_path = tmp_dir / "utest_score_t5_sum.hypo"

            testargs = [
                "evaluate_cnn.py",
                "patrickvonplaten/t5-tiny-random",
                str(articles_path),
                str(output_path),
                str(articles_path),
                str(score_path),
            ]

            with patch.object(sys, "argv", testargs):
                run_generate()

            # Check before the temporary directory is removed on exiting the with-block.
            self.assertTrue(output_path.exists())
            self.assertTrue(score_path.exists())