# test_t5_examples.py
import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_wmt import run_generate


# One-sentence English source and its German reference translation; they are
# written to temporary files to serve as a tiny source/target corpus.
text = ["When Liana Barrientos was 23 years old, she got married in Westchester County."]
translation = ["Als Liana Barrientos 23 Jahre alt war, heiratete sie in Westchester County."]

# NOTE(review): these module-level file names are not read by the test method
# in this module (it builds its own temp-dir paths); confirm nothing else
# imports them before removing.
output_file_name = "output_t5_trans.txt"
score_file_name = "score_t5_trans.txt"

# Handle on the root logger (getLogger() with no name returns the root logger);
# the test attaches a stdout handler to it.
logger = logging.getLogger()

# Request DEBUG-level output at import time. basicConfig does nothing if the
# root logger already has handlers configured.
logging.basicConfig(level=logging.DEBUG)


class TestT5Examples(unittest.TestCase):
    """Smoke test for the ``evaluate_wmt`` command-line entry point."""

    def test_t5_cli(self):
        """Run ``run_generate`` on a one-line corpus and check its output files exist.

        ``sys.argv`` is patched with the model id followed by the source,
        output, target and score paths. Only the existence of the output and
        score files is asserted, not their contents.
        """
        # Mirror log records to stdout for this test only. Removing the handler
        # afterwards keeps it from piling up on the root logger across tests.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        self.addCleanup(logger.removeHandler, stream_handler)

        # Use a private, auto-deleted directory instead of fixed names in the
        # shared system temp dir. This prevents concurrent runs from clobbering
        # each other's files and leaves nothing behind.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        tmp_path = Path(tmp_dir.name)

        tmp_source = tmp_path / "utest_generations_t5_trans.hypo"
        tmp_source.write_text("\n".join(text), encoding="utf-8")

        tmp_target = tmp_path / "utest_generations_t5_trans.target"
        tmp_target.write_text("\n".join(translation), encoding="utf-8")

        # Distinct local names avoid shadowing the module-level file-name constants.
        output_path = tmp_path / "utest_output_trans.hypo"
        score_path = tmp_path / "utest_score.hypo"

        testargs = [
            "evaluate_wmt.py",
            "patrickvonplaten/t5-tiny-random",
            str(tmp_source),
            str(output_path),
            str(tmp_target),
            str(score_path),
        ]

        with patch.object(sys, "argv", testargs):
            run_generate()

        self.assertTrue(output_path.exists())
        self.assertTrue(score_path.exists())