test_examples.py 7.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16
17

import argparse
import logging
import os
import sys
from unittest.mock import patch

import torch

from transformers.file_utils import is_apex_available
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, torch_device


# Absolute paths to the example-script directories that sit next to this test
# file; each one contains one of the runnable `run_*.py` scripts imported below.
SRC_DIRS = [
    os.path.join(os.path.dirname(__file__), dirname)
    for dirname in [
        "text-generation",
        "text-classification",
        "token-classification",
        "language-modeling",
        "question-answering",
    ]
]
# Make the example scripts importable as top-level modules (import run_glue, ...).
sys.path.extend(SRC_DIRS)


if SRC_DIRS is not None:
    # NOTE(review): SRC_DIRS is always a list here, so this guard never skips;
    # it mainly serves to group the example-script imports visually.
    import run_clm
    import run_generation
    import run_glue
    import run_mlm
    import run_ner
    import run_qa as run_squad


# Verbose logging so the example scripts' progress shows up in test output.
logging.basicConfig(level=logging.DEBUG)

# Root logger; the tests below attach a stdout stream handler to it so that
# pytest captures the scripts' log records.
logger = logging.getLogger()
def get_setup_file():
    """Return the value passed on the command line via the ``-f`` option."""
    cli = argparse.ArgumentParser()
    cli.add_argument("-f")
    parsed = cli.parse_args()
    return parsed.f
def is_cuda_and_apex_available():
    """Return True when CUDA is both present and selected AND apex is installed."""
    cuda_ready = torch.cuda.is_available() and torch_device == "cuda"
    if not cuda_ready:
        return False
    return is_apex_available()
class ExamplesTests(TestCasePlus):
    """End-to-end smoke tests for the example scripts.

    Each test patches ``sys.argv`` with a small CLI invocation, calls the
    script's ``main()`` on tiny fixture datasets, and sanity-checks the
    metrics dict the script returns.
    """

    @require_torch_non_multi_gpu_but_fix_me
    def test_run_glue(self):
        """Fine-tune on the MRPC fixture and require every metric >= 0.75."""
        # Mirror the script's log output on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_glue.py
            --model_name_or_path distilbert-base-uncased
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
            --do_train
            --do_eval
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --learning_rate=1e-4
            --max_steps=10
            --warmup_steps=2
            --seed=42
            --max_seq_length=128
            """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            result = run_glue.main()
            # eval_loss has no meaningful lower bound, so drop it before the
            # blanket >= 0.75 quality check on the remaining metrics.
            del result["eval_loss"]
            for value in result.values():
                self.assertGreaterEqual(value, 0.75)

    @require_torch_non_multi_gpu_but_fix_me
    def test_run_clm(self):
        """Train a tiny causal LM and require eval perplexity below 100."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_clm.py
            --model_name_or_path distilgpt2
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --do_train
            --do_eval
            --block_size 128
            --per_device_train_batch_size 5
            --per_device_eval_batch_size 5
            --num_train_epochs 2
            --output_dir {tmp_dir}
            --overwrite_output_dir
            """.split()

        if torch.cuda.device_count() > 1:
            # Skipping because there are not enough batches to train the model + would need a drop_last to work.
            return

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = run_clm.main()
            self.assertLess(result["perplexity"], 100)

    @require_torch_non_multi_gpu_but_fix_me
    def test_run_mlm(self):
        """Train a tiny masked LM and require eval perplexity below 42."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_mlm.py
            --model_name_or_path distilroberta-base
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --prediction_loss_only
            --num_train_epochs=1
        """.split()

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = run_mlm.main()
            self.assertLess(result["perplexity"], 42)

    @require_torch_non_multi_gpu_but_fix_me
    def test_run_ner(self):
        """Token classification on the CoNLL fixture; check accuracy/precision/loss."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_ner.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/conll/sample.json
            --validation_file tests/fixtures/tests_samples/conll/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --warmup_steps=2
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=2
            --num_train_epochs=2
        """.split()

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = run_ner.main()
            self.assertGreaterEqual(result["eval_accuracy_score"], 0.75)
            self.assertGreaterEqual(result["eval_precision"], 0.75)
            self.assertLess(result["eval_loss"], 0.5)

    @require_torch_non_multi_gpu_but_fix_me
    def test_run_squad(self):
        """Question answering on the SQuAD v2 fixture; check f1/exact >= 30."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_squad.py
            --model_name_or_path bert-base-uncased
            --version_2_with_negative
            --train_file tests/fixtures/tests_samples/SQUAD/sample.json
            --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=10
            --warmup_steps=2
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
        """.split()

        with patch.object(sys, "argv", testargs):
            result = run_squad.main()
            self.assertGreaterEqual(result["f1"], 30)
            self.assertGreaterEqual(result["exact"], 30)

    @require_torch_non_multi_gpu_but_fix_me
    def test_generation(self):
        """Generate 10 tokens from a tiny GPT-2 and check output length."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        model_type, model_name = (
            "--model_type=gpt2",
            "--model_name_or_path=sshleifer/tiny-gpt2",
        )
        with patch.object(sys, "argv", testargs + [model_type, model_name]):
            result = run_generation.main()
            # result is presumably a list of generated sequences — the first
            # one must contain at least the requested 10 tokens/characters.
            self.assertGreaterEqual(len(result[0]), 10)