# test_bart_examples.py
import logging
import os
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_cnn import _run_generate


# Name of the summary file the CLI test expects ``_run_generate`` to write.
output_file_name = "output_bart_sum.txt"

# Sample input article; the CLI test writes these lines to the source file it
# passes to ``_run_generate``.
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]

# Verbose logging so generation progress is visible in test output.
logging.basicConfig(level=logging.DEBUG)

# Root logger (getLogger() with no name); the CLI test attaches a stdout
# handler to it while it runs.
logger = logging.getLogger()


class TestBartExamples(unittest.TestCase):
    """Smoke test for the ``evaluate_cnn.py`` command-line entry point."""

    def test_bart_cnn_cli(self):
        """Run ``_run_generate`` with a patched ``sys.argv`` and check it writes an output file.

        Both the input articles and the generated output live in a private
        temporary directory. The test therefore does not collide with
        concurrent runs and leaves no files behind, even when an assertion
        fails.
        """
        # Mirror log output to stdout for this test only. Detaching the handler
        # afterwards keeps repeated runs from stacking duplicate handlers on the
        # module-level logger.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        self.addCleanup(logger.removeHandler, stream_handler)

        # The cleanup callback removes the directory (and everything in it)
        # whether the test passes or fails.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        tmp_path = Path(tmp_dir.name)

        # Source file for the CLI: one article per line.
        source_path = tmp_path / "utest_generations_bart_sum.hypo"
        source_path.write_text("\n".join(articles), encoding="utf-8")
        output_path = tmp_path / output_file_name

        # argv[0] is the script name, followed by the positional input and
        # output paths, matching what the CLI is invoked with.
        testargs = ["evaluate_cnn.py", str(source_path), str(output_path)]
        with patch.object(sys, "argv", testargs):
            _run_generate()

        self.assertTrue(output_path.exists())