"""Smoke test for the ``evaluate_wmt`` command-line entry point (T5 translation)."""
import logging
import os
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_wmt import run_generate


# A one-sentence corpus: an English source sentence and its German reference translation.
text = ["When Liana Barrientos was 23 years old, she got married in Westchester County."]
translation = ["Als Liana Barrientos 23 Jahre alt war, heiratete sie in Westchester County."]

# Base names of the generated-translation and score files passed to run_generate.
output_file_name = "output_t5_trans.txt"
score_file_name = "score_t5_trans.txt"

# NOTE: configures logging for the whole process at import time (module-level side effect).
logging.basicConfig(level=logging.DEBUG)

# getLogger() with no name returns the root logger.
logger = logging.getLogger()


class TestT5Examples(unittest.TestCase):
    """End-to-end smoke test for the ``evaluate_wmt`` command-line entry point."""

    def test_t5_cli(self):
        """Run ``run_generate`` on a one-sentence corpus and check that it writes its outputs.

        The source file, reference file, and both output files all live in a
        private temporary directory. The test therefore leaves nothing behind in
        the working directory or the shared temp dir, even when it fails.
        """
        # Echo log records to stdout for this test only. addCleanup removes the
        # handler afterwards, even on failure, so handlers do not accumulate on
        # the root logger across tests.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        self.addCleanup(logger.removeHandler, stream_handler)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            tmp_dir = Path(tmp_dir_name)

            tmp_source = tmp_dir / "utest_generations_t5_trans.hypo"
            tmp_source.write_text("\n".join(text), encoding="utf-8")

            tmp_target = tmp_dir / "utest_generations_t5_trans.target"
            tmp_target.write_text("\n".join(translation), encoding="utf-8")

            # Outputs go inside the temp dir rather than the CWD, so cleanup is automatic.
            output_path = tmp_dir / output_file_name
            score_path = tmp_dir / score_file_name

            # run_generate takes its arguments from the patched sys.argv (argv[0] is the
            # script name). The positional order matches the original test:
            # source, output, reference, score.
            testargs = [
                "evaluate_wmt.py",
                str(tmp_source),
                str(output_path),
                str(tmp_target),
                str(score_path),
            ]
            with patch.object(sys, "argv", testargs):
                run_generate()

            self.assertTrue(output_path.exists())
            self.assertTrue(score_path.exists())