# test_run_glue_with_pabee.py
import argparse
import logging
import sys
import unittest
from unittest.mock import patch

import run_glue_with_pabee


# Configure the root logger at DEBUG level for this test module.
logging.basicConfig(level=logging.DEBUG)

# getLogger() with no name returns the root logger; test_run_glue attaches a
# stdout handler to it.
logger = logging.getLogger()


def get_setup_file():
    """Return the value given to the ``-f`` command-line option.

    The process arguments are parsed with a parser that knows only ``-f``.
    The result is ``None`` when the option is not supplied.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-f")
    return cli.parse_args().f


class PabeeTests(unittest.TestCase):
    """End-to-end test of the ``run_glue_with_pabee`` example script."""

    @unittest.skip("Disable while Canwen investigates.")
    def test_run_glue(self):
        """Run the script's ``main()`` on the MRPC fixture and check its results.

        ``sys.argv`` is patched with a short ALBERT training/evaluation
        configuration (50 steps, seed 42). Every value in the mapping returned
        by ``main()`` must be at least 0.75.
        """
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Detach the handler when the test ends (pass, fail or error) so it does
        # not pile up on the shared logger and duplicate output in later tests.
        self.addCleanup(logger.removeHandler, stream_handler)

        testargs = """
            run_glue_with_pabee.py
            --model_type albert
            --model_name_or_path albert-base-v2
            --data_dir ./tests/fixtures/tests_samples/MRPC/
            --task_name mrpc
            --do_train
            --do_eval
            --output_dir ./tests/fixtures/tests_samples/temp_dir
            --per_gpu_train_batch_size=2
            --per_gpu_eval_batch_size=1
            --learning_rate=2e-5
            --max_steps=50
            --warmup_steps=2
            --overwrite_output_dir
            --seed=42
            --max_seq_length=128
            """.split()
        # Only the call to main() needs the patched argv.
        with patch.object(sys, "argv", testargs):
            result = run_glue_with_pabee.main()

        for name, value in result.items():
            # Name the offending metric so a failure is actionable.
            self.assertGreaterEqual(value, 0.75, msg="metric: {}".format(name))