# test_glue_deebert.py
import argparse
import logging
import sys
import unittest
from unittest.mock import patch

import run_glue_deebert
8
from transformers.testing_utils import slow
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23


# Root logger shared by the tests in this module.
logger = logging.getLogger()

# Configure root logging at DEBUG verbosity.
logging.basicConfig(level=logging.DEBUG)


def get_setup_file():
    """Return the value given to the ``-f`` command-line option.

    The process command line is parsed with a parser that knows only ``-f``;
    the result is ``None`` when the option is not supplied.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-f")
    return cli.parse_args().f


class DeeBertTests(unittest.TestCase):
    """Drive ``run_glue_deebert.main()`` end to end through a patched ``sys.argv``."""

    def setUp(self) -> None:
        """Mirror log records to stdout while a test runs.

        The hook must be spelled ``setUp`` for ``unittest`` to invoke it; a
        method named ``setup`` is never called by the framework.
        """
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Detach the handler afterwards so repeated tests do not stack
        # duplicate handlers on the shared logger.
        self.addCleanup(logger.removeHandler, stream_handler)

    def _run_and_check(self, argv, threshold=0.666):
        """Run ``run_glue_deebert.main()`` with ``argv`` and check its results.

        Args:
            argv: Full argument vector (program name first) installed as
                ``sys.argv`` for the duration of the call.
            threshold: Lower bound every value in the returned mapping must
                meet or exceed.
        """
        with patch.object(sys, "argv", argv):
            result = run_glue_deebert.main()
            for value in result.values():
                self.assertGreaterEqual(value, threshold)

    @slow
    def test_glue_deebert_train(self):
        """Train on the MRPC fixture, then evaluate the saved model twice.

        The two evaluation stages pass the training stage's ``--output_dir``
        as ``--model_name_or_path``, so the stages must run in this order.
        """
        train_args = """
            run_glue_deebert.py
            --model_type roberta
            --model_name_or_path roberta-base
            --task_name MRPC
            --do_train
            --do_eval
            --do_lower_case
            --data_dir ./tests/fixtures/tests_samples/MRPC/
            --max_seq_length 128
            --per_gpu_eval_batch_size=1
            --per_gpu_train_batch_size=8
            --learning_rate 2e-4
            --num_train_epochs 3
            --overwrite_output_dir
            --seed 42
            --output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
            --plot_data_dir ./examples/deebert/results/
            --save_steps 0
            --overwrite_cache
            --eval_after_first_stage
            """.split()
        self._run_and_check(train_args)

        # Evaluation with --eval_each_highway and --eval_highway on the saved model.
        eval_args = """
            run_glue_deebert.py
            --model_type roberta
            --model_name_or_path ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
            --task_name MRPC
            --do_eval
            --do_lower_case
            --data_dir ./tests/fixtures/tests_samples/MRPC/
            --output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
            --plot_data_dir ./examples/deebert/results/
            --max_seq_length 128
            --eval_each_highway
            --eval_highway
            --overwrite_cache
            --per_gpu_eval_batch_size=1
            """.split()
        self._run_and_check(eval_args)

        # Evaluation with --early_exit_entropy 0.1 on the same saved model.
        entropy_eval_args = """
            run_glue_deebert.py
            --model_type roberta
            --model_name_or_path ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
            --task_name MRPC
            --do_eval
            --do_lower_case
            --data_dir ./tests/fixtures/tests_samples/MRPC/
            --output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
            --plot_data_dir ./examples/deebert/results/
            --max_seq_length 128
            --early_exit_entropy 0.1
            --eval_highway
            --overwrite_cache
            --per_gpu_eval_batch_size=1
            """.split()
        self._run_and_check(entropy_eval_args)