# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import os
import sys
from unittest.case import skip
from unittest.mock import patch

import torch

from transformers.file_utils import is_apex_available
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device


# Absolute paths to the per-task example directories that sit next to this
# test file; each one contains a runnable `run_*.py` script.
SRC_DIRS = [
    os.path.join(os.path.dirname(__file__), dirname)
    for dirname in [
        "text-generation",
        "text-classification",
        "token-classification",
        "language-modeling",
        "multiple-choice",
        "question-answering",
        "summarization",
        "translation",
        "image-classification",
    ]
]
# Make the example scripts importable as top-level modules (import run_glue, ...).
sys.path.extend(SRC_DIRS)


# NOTE: SRC_DIRS is a list literal above, so this condition is always true;
# it only serves to visually group the example-script imports that depend on
# the sys.path extension just performed.
if SRC_DIRS is not None:
    import run_clm
    import run_generation
    import run_glue
    import run_image_classification
    import run_mlm
    import run_ner
    import run_qa as run_squad
    import run_summarization
    import run_swag
    import run_translation


logging.basicConfig(level=logging.DEBUG)

# Root logger: the tests below attach a stdout handler to it so the example
# scripts' output is visible in the captured test logs.
logger = logging.getLogger()
def get_setup_file():
    """Return the value of the ``-f`` flag from the command line.

    Test runners (e.g. IPython/Jupyter kernels) pass ``-f <file>``; this
    retrieves that file path from ``sys.argv``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-f")
    parsed = arg_parser.parse_args()
    return parsed.f


def get_results(output_dir):
    """Load the metrics an example script wrote to ``all_results.json``.

    Args:
        output_dir: directory the script used as its ``--output_dir``.

    Returns:
        dict: the parsed metrics.

    Raises:
        ValueError: if ``all_results.json`` does not exist in ``output_dir``.
    """
    results_path = os.path.join(output_dir, "all_results.json")
    if not os.path.exists(results_path):
        raise ValueError(f"can't find {results_path}")
    with open(results_path, "r") as f:
        return json.load(f)


def is_cuda_and_apex_available():
    """True only when the test run targets CUDA, CUDA is usable, and apex is installed."""
    if torch_device != "cuda":
        return False
    return torch.cuda.is_available() and is_apex_available()


class ExamplesTests(TestCasePlus):
    """End-to-end smoke tests for the PyTorch example scripts.

    Each test builds a command line in ``testargs``, patches ``sys.argv`` with
    it, calls the script's ``main()`` in-process, then checks the metrics the
    script wrote to ``all_results.json`` via :func:`get_results`. Output dirs
    come from ``TestCasePlus.get_auto_remove_tmp_dir`` and are cleaned up
    automatically.

    NOTE(review): every test adds a fresh ``StreamHandler`` to the root logger
    without removing it, so handlers accumulate across tests — presumably
    harmless duplication of log output; confirm before relying on log capture.
    """

    def test_run_glue(self):
        """Train+eval text classification (run_glue) on the tiny MRPC fixture."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_glue.py
            --model_name_or_path distilbert-base-uncased
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
            --do_train
            --do_eval
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --learning_rate=1e-4
            --max_steps=10
            --warmup_steps=2
            --seed=42
            --max_seq_length=128
            """.split()

        # Exercise the fp16 path too when the hardware/software supports it.
        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_glue.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)

    def test_run_clm(self):
        """Train+eval causal language modeling (run_clm) on the sample text."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_clm.py
            --model_name_or_path distilgpt2
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --do_train
            --do_eval
            --block_size 128
            --per_device_train_batch_size 5
            --per_device_eval_batch_size 5
            --num_train_epochs 2
            --output_dir {tmp_dir}
            --overwrite_output_dir
            """.split()

        if torch.cuda.device_count() > 1:
            # Skipping because there are not enough batches to train the model + would need a drop_last to work.
            return

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_clm.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 100)

    def test_run_mlm(self):
        """Train+eval masked language modeling (run_mlm) on the sample text."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_mlm.py
            --model_name_or_path distilroberta-base
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --prediction_loss_only
            --num_train_epochs=1
        """.split()

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_mlm.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 42)

    def test_run_ner(self):
        """Train+eval token classification (run_ner) on the conll fixture."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
        epochs = 7 if get_gpu_count() > 1 else 2

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_ner.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/conll/sample.json
            --validation_file tests/fixtures/tests_samples/conll/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --warmup_steps=2
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=2
            --num_train_epochs={epochs}
            --seed 7
        """.split()

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_ner.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
            self.assertLess(result["eval_loss"], 0.5)

    def test_run_squad(self):
        """Train+eval question answering (run_qa) on the SQUAD v2-style fixture."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_qa.py
            --model_name_or_path bert-base-uncased
            --version_2_with_negative
            --train_file tests/fixtures/tests_samples/SQUAD/sample.json
            --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=10
            --warmup_steps=2
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
        """.split()

        with patch.object(sys, "argv", testargs):
            run_squad.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_f1"], 30)
            self.assertGreaterEqual(result["eval_exact"], 30)

    def test_run_swag(self):
        """Train+eval multiple choice (run_swag) on the swag fixture."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_swag.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/swag/sample.json
            --validation_file tests/fixtures/tests_samples/swag/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=20
            --warmup_steps=2
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
        """.split()

        with patch.object(sys, "argv", testargs):
            run_swag.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)

    def test_generation(self):
        """Run text generation (run_generation) with a tiny GPT-2 and check output length."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        model_type, model_name = (
            "--model_type=gpt2",
            "--model_name_or_path=sshleifer/tiny-gpt2",
        )
        with patch.object(sys, "argv", testargs + [model_type, model_name]):
            # run_generation.main() returns the generated sequences directly,
            # so no all_results.json is read here.
            result = run_generation.main()
            self.assertGreaterEqual(len(result[0]), 10)

    @slow
    def test_run_summarization(self):
        """Train+eval summarization (run_summarization) on the xsum fixture; slow test."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_summarization.py
            --model_name_or_path t5-small
            --train_file tests/fixtures/tests_samples/xsum/sample.json
            --validation_file tests/fixtures/tests_samples/xsum/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=50
            --warmup_steps=8
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --predict_with_generate
        """.split()

        with patch.object(sys, "argv", testargs):
            run_summarization.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_rouge1"], 10)
            self.assertGreaterEqual(result["eval_rouge2"], 2)
            self.assertGreaterEqual(result["eval_rougeL"], 7)
            self.assertGreaterEqual(result["eval_rougeLsum"], 7)

    @slow
    def test_run_translation(self):
        """Train+eval translation (run_translation) on the wmt16 fixture; slow test."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        # NOTE(review): --source_lang/--target_lang appear twice below
        # (en/ro then en_XX/ro_RO); with a standard argparse "store" action the
        # later pair wins — confirm this duplication is intentional.
        testargs = f"""
            run_translation.py
            --model_name_or_path sshleifer/student_marian_en_ro_6_1
            --source_lang en
            --target_lang ro
            --train_file tests/fixtures/tests_samples/wmt16/sample.json
            --validation_file tests/fixtures/tests_samples/wmt16/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=50
            --warmup_steps=8
            --do_train
            --do_eval
            --learning_rate=3e-3
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --predict_with_generate
            --source_lang en_XX
            --target_lang ro_RO
        """.split()

        with patch.object(sys, "argv", testargs):
            run_translation.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_bleu"], 30)

    @skip("The test is failing as accuracy is 0, re-enable when fixed.")
    def test_run_image_classification(self):
        """Train+eval image classification (run_image_classification) on the cats_and_dogs fixture."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_image_classification.py
            --output_dir {tmp_dir}
            --model_name_or_path google/vit-base-patch16-224-in21k
            --train_dir tests/fixtures/tests_samples/cats_and_dogs/
            --do_train
            --do_eval
            --learning_rate 2e-5
            --per_device_train_batch_size 2
            --per_device_eval_batch_size 1
            --remove_unused_columns False
            --overwrite_output_dir True
            --dataloader_num_workers 16
            --metric_for_best_model accuracy
            --max_steps 30
            --train_val_split 0.1
            --seed 7
        """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_image_classification.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)