test_examples.py 14.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16
17

import argparse
18
import json
19
import logging
20
import os
Aymeric Augustin's avatar
Aymeric Augustin committed
21
import sys
Aymeric Augustin's avatar
Aymeric Augustin committed
22
from unittest.mock import patch
Aymeric Augustin's avatar
Aymeric Augustin committed
23

Stas Bekman's avatar
Stas Bekman committed
24
25
import torch

26
from transformers.file_utils import is_apex_available
27
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
28

29
30
31

# Absolute paths to each example-script directory, resolved relative to this
# test file so the tests work regardless of the current working directory.
SRC_DIRS = [
    os.path.join(os.path.dirname(__file__), dirname)
    for dirname in [
        "text-generation",
        "text-classification",
        "token-classification",
        "language-modeling",
        "multiple-choice",
        "question-answering",
        "summarization",
        "translation",
        "image-classification",
        "speech-recognition",
        "audio-classification",
    ]
]
# Put every example directory on sys.path so the run_* scripts below can be
# imported as plain modules.
sys.path.extend(SRC_DIRS)


if SRC_DIRS is not None:
    import run_audio_classification
    import run_clm
    import run_generation
    import run_glue
    import run_image_classification
    import run_mlm
    import run_ner
    import run_qa as run_squad
    import run_speech_recognition_ctc
    import run_summarization
    import run_swag
    import run_translation
Aymeric Augustin's avatar
Aymeric Augustin committed
62

63

64
65
66
# Verbose logging so output from the example scripts shows up in test logs.
logging.basicConfig(level=logging.DEBUG)

# Root logger; individual tests attach a stdout handler to it so pytest
# captures the example scripts' output.
logger = logging.getLogger()
67

68

69
70
def get_setup_file():
    """Return the value passed on the command line via the ``-f`` flag."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-f")
    parsed = arg_parser.parse_args()
    return parsed.f


76
77
78
79
80
81
82
83
84
85
86
def get_results(output_dir):
    """Load and return the metrics dict from ``all_results.json`` in *output_dir*.

    Raises ``ValueError`` if the results file does not exist.
    """
    path = os.path.join(output_dir, "all_results.json")
    if not os.path.exists(path):
        raise ValueError(f"can't find {path}")
    with open(path, "r") as f:
        return json.load(f)


87
def is_cuda_and_apex_available():
    """Return True when the tests run on a CUDA device and NVIDIA apex is installed."""
    if torch_device != "cuda":
        return False
    return torch.cuda.is_available() and is_apex_available()


92
class ExamplesTests(TestCasePlus):
    """End-to-end smoke tests for the PyTorch example scripts.

    Each test builds a CLI argument list, patches ``sys.argv``, invokes the
    example script's ``main()`` on a tiny fixture dataset, then checks the
    metrics written to ``all_results.json`` against loose quality thresholds.
    """

    def test_run_glue(self):
        """Fine-tune distilbert on the MRPC fixture and check eval accuracy."""
        # Mirror logging to stdout so pytest captures the script's output.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_glue.py
            --model_name_or_path distilbert-base-uncased
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
            --do_train
            --do_eval
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --learning_rate=1e-4
            --max_steps=10
            --warmup_steps=2
            --seed=42
            --max_seq_length=128
            """.split()

        # Only exercise mixed precision when both CUDA and apex are present.
        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_glue.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)

    def test_run_clm(self):
        """Train distilgpt2 with causal LM on a text fixture; bound perplexity."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_clm.py
            --model_name_or_path distilgpt2
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --do_train
            --do_eval
            --block_size 128
            --per_device_train_batch_size 5
            --per_device_eval_batch_size 5
            --num_train_epochs 2
            --output_dir {tmp_dir}
            --overwrite_output_dir
            """.split()

        if torch.cuda.device_count() > 1:
            # Skipping because there are not enough batches to train the model + would need a drop_last to work.
            return

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_clm.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 100)

    def test_run_mlm(self):
        """Train distilroberta-base with masked LM on a text fixture."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_mlm.py
            --model_name_or_path distilroberta-base
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --prediction_loss_only
            --num_train_epochs=1
        """.split()

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_mlm.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 42)

    def test_run_ner(self):
        """Fine-tune BERT for token classification on a CoNLL fixture."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
        epochs = 7 if get_gpu_count() > 1 else 2

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_ner.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/conll/sample.json
            --validation_file tests/fixtures/tests_samples/conll/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --warmup_steps=2
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=2
            --num_train_epochs={epochs}
            --seed 7
        """.split()

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_ner.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
            self.assertLess(result["eval_loss"], 0.5)

    def test_run_squad(self):
        """Fine-tune BERT for extractive QA on a SQuAD-v2-style fixture."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_qa.py
            --model_name_or_path bert-base-uncased
            --version_2_with_negative
            --train_file tests/fixtures/tests_samples/SQUAD/sample.json
            --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=10
            --warmup_steps=2
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
        """.split()

        with patch.object(sys, "argv", testargs):
            run_squad.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_f1"], 30)
            self.assertGreaterEqual(result["eval_exact"], 30)

    def test_run_swag(self):
        """Fine-tune BERT for multiple choice on a SWAG fixture."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_swag.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/swag/sample.json
            --validation_file tests/fixtures/tests_samples/swag/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=20
            --warmup_steps=2
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
        """.split()

        with patch.object(sys, "argv", testargs):
            run_swag.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)

    def test_generation(self):
        """Generate text with a tiny GPT-2 checkpoint and check output length."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        model_type, model_name = (
            "--model_type=gpt2",
            "--model_name_or_path=sshleifer/tiny-gpt2",
        )
        with patch.object(sys, "argv", testargs + [model_type, model_name]):
            # run_generation.main() returns the generated sequences.
            result = run_generation.main()
            self.assertGreaterEqual(len(result[0]), 10)

    @slow
    def test_run_summarization(self):
        """Fine-tune t5-small for summarization on an XSum fixture; check ROUGE."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_summarization.py
            --model_name_or_path t5-small
            --train_file tests/fixtures/tests_samples/xsum/sample.json
            --validation_file tests/fixtures/tests_samples/xsum/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=50
            --warmup_steps=8
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --predict_with_generate
        """.split()

        with patch.object(sys, "argv", testargs):
            run_summarization.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_rouge1"], 10)
            self.assertGreaterEqual(result["eval_rouge2"], 2)
            self.assertGreaterEqual(result["eval_rougeL"], 7)
            self.assertGreaterEqual(result["eval_rougeLsum"], 7)

    @slow
    def test_run_translation(self):
        """Fine-tune a small Marian model for en->ro translation; check BLEU."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_translation.py
            --model_name_or_path sshleifer/student_marian_en_ro_6_1
            --source_lang en
            --target_lang ro
            --train_file tests/fixtures/tests_samples/wmt16/sample.json
            --validation_file tests/fixtures/tests_samples/wmt16/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=50
            --warmup_steps=8
            --do_train
            --do_eval
            --learning_rate=3e-3
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --predict_with_generate
            --source_lang en_XX
            --target_lang ro_RO
        """.split()

        with patch.object(sys, "argv", testargs):
            run_translation.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_bleu"], 30)

    def test_run_image_classification(self):
        """Fine-tune a ViT backbone on a tiny cats-vs-dogs sample."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_image_classification.py
            --output_dir {tmp_dir}
            --model_name_or_path google/vit-base-patch16-224-in21k
            --dataset_name hf-internal-testing/cats_vs_dogs_sample
            --do_train
            --do_eval
            --learning_rate 1e-4
            --per_device_train_batch_size 2
            --per_device_eval_batch_size 1
            --remove_unused_columns False
            --overwrite_output_dir True
            --dataloader_num_workers 16
            --metric_for_best_model accuracy
            --max_steps 10
            --train_val_split 0.1
            --seed 42
        """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_image_classification.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)

    def test_run_speech_recognition_ctc(self):
        """Train a tiny wav2vec2 CTC model on a librispeech dummy sample."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_speech_recognition_ctc.py
            --output_dir {tmp_dir}
            --model_name_or_path hf-internal-testing/tiny-random-wav2vec2
            --dataset_name patrickvonplaten/librispeech_asr_dummy
            --dataset_config_name clean
            --train_split_name validation
            --eval_split_name validation
            --audio_column_name file
            --do_train
            --do_eval
            --learning_rate 1e-4
            --per_device_train_batch_size 2
            --per_device_eval_batch_size 1
            --remove_unused_columns False
            --overwrite_output_dir True
            --preprocessing_num_workers 16
            --max_steps 10
            --seed 42
        """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_speech_recognition_ctc.main()
            result = get_results(tmp_dir)
            # Only a sanity check: eval loss should be below the train loss.
            self.assertLess(result["eval_loss"], result["train_loss"])

    def test_run_audio_classification(self):
        """Train a tiny wav2vec2 classifier on a keyword-spotting demo set."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_audio_classification.py
            --output_dir {tmp_dir}
            --model_name_or_path hf-internal-testing/tiny-random-wav2vec2
            --dataset_name anton-l/superb_demo
            --dataset_config_name ks
            --train_split_name test
            --eval_split_name test
            --audio_column_name file
            --label_column_name label
            --do_train
            --do_eval
            --learning_rate 1e-4
            --per_device_train_batch_size 2
            --per_device_eval_batch_size 1
            --remove_unused_columns False
            --overwrite_output_dir True
            --num_train_epochs 10
            --max_steps 50
            --seed 42
        """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_audio_classification.main()
            result = get_results(tmp_dir)
            # Only a sanity check: eval loss should be below the train loss.
            self.assertLess(result["eval_loss"], result["train_loss"])