"docs/source/en/tasks/language_modeling.md" did not exist on "0ce5236dd11cd34585d6d3e4d05e0cd3094b3796"
test_examples.py 13.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16
17

import argparse
import json
import logging
import os
import sys
from unittest.mock import patch

import torch

from transformers.file_utils import is_apex_available
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device


# Example-script directories (relative to this file). Each one is appended to
# sys.path below so that the scripts inside can be imported as plain modules
# (e.g. ``import run_glue``).
SRC_DIRS = [
    os.path.join(os.path.dirname(__file__), dirname)
    for dirname in [
        "text-generation",
        "text-classification",
        "token-classification",
        "language-modeling",
        "multiple-choice",
        "question-answering",
        "summarization",
        "translation",
        "image-classification",
        "speech-recognition",
    ]
]
sys.path.extend(SRC_DIRS)


if SRC_DIRS is not None:
    # NOTE(review): SRC_DIRS is a list literal defined above, so this guard is
    # always true; it is kept only to preserve the original structure.
    import run_clm
    import run_generation
    import run_glue
    import run_image_classification
    import run_mlm
    import run_ner
    import run_qa as run_squad
    import run_speech_recognition_ctc
    import run_summarization
    import run_swag
    import run_translation
60

61

62
63
64
# Emit every log record during the tests so failures are easy to diagnose.
logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
65

66

67
68
def get_setup_file():
    """Return the value passed on the command line via the ``-f`` flag."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-f")
    parsed = arg_parser.parse_args()
    return parsed.f


74
75
76
77
78
79
80
81
82
83
84
def get_results(output_dir):
    """Load and return the ``all_results.json`` metrics written by an example run.

    Args:
        output_dir: Directory in which the example script saved its results.

    Returns:
        dict: The parsed metrics dictionary.

    Raises:
        ValueError: If ``all_results.json`` does not exist in ``output_dir``.
    """
    path = os.path.join(output_dir, "all_results.json")
    try:
        # EAFP instead of an exists()-then-open check (avoids a TOCTOU race);
        # explicit encoding so the JSON parses identically on every platform.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Preserve the original exception type so existing callers keep working.
        raise ValueError(f"can't find {path}") from None


85
def is_cuda_and_apex_available():
    """Return True when the tests run on a CUDA device and NVIDIA apex is installed."""
    cuda_ready = torch_device == "cuda" and torch.cuda.is_available()
    return cuda_ready and is_apex_available()


90
class ExamplesTests(TestCasePlus):
    """End-to-end smoke tests that run each example script on tiny fixture datasets."""

    def test_run_glue(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_glue.py
            --model_name_or_path distilbert-base-uncased
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
            --do_train
            --do_eval
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --learning_rate=1e-4
            --max_steps=10
            --warmup_steps=2
            --seed=42
            --max_seq_length=128
            """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_glue.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
121

Sylvain Gugger's avatar
Sylvain Gugger committed
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
    def test_run_clm(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_clm.py
            --model_name_or_path distilgpt2
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --do_train
            --do_eval
            --block_size 128
            --per_device_train_batch_size 5
            --per_device_eval_batch_size 5
            --num_train_epochs 2
            --output_dir {tmp_dir}
            --overwrite_output_dir
            """.split()

        if torch.cuda.device_count() > 1:
            # Skipping because there are not enough batches to train the model + would need a drop_last to work.
            return

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_clm.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 100)

154
    def test_run_mlm(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_mlm.py
            --model_name_or_path distilroberta-base
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --prediction_loss_only
            --num_train_epochs=1
        """.split()

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_mlm.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 42)
Julien Chaumond's avatar
Julien Chaumond committed
179

180
181
182
183
    def test_run_ner(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
        epochs = 7 if get_gpu_count() > 1 else 2

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_ner.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/conll/sample.json
            --validation_file tests/fixtures/tests_samples/conll/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --warmup_steps=2
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=2
            --num_train_epochs={epochs}
            --seed 7
        """.split()

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_ner.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
            self.assertLess(result["eval_loss"], 0.5)

214
215
216
217
    def test_run_squad(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_qa.py
            --model_name_or_path bert-base-uncased
            --version_2_with_negative
            --train_file tests/fixtures/tests_samples/SQUAD/sample.json
            --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=10
            --warmup_steps=2
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
        """.split()

        with patch.object(sys, "argv", testargs):
            run_squad.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_f1"], 30)
            self.assertGreaterEqual(result["eval_exact"], 30)
241

242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
    def test_run_swag(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_swag.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/swag/sample.json
            --validation_file tests/fixtures/tests_samples/swag/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=20
            --warmup_steps=2
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
        """.split()

        with patch.object(sys, "argv", testargs):
            run_swag.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)

268
269
270
271
    def test_generation(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        model_type, model_name = (
            "--model_type=gpt2",
            "--model_name_or_path=sshleifer/tiny-gpt2",
        )
        with patch.object(sys, "argv", testargs + [model_type, model_name]):
            result = run_generation.main()
            self.assertGreaterEqual(len(result[0]), 10)
284
285

    @slow
    def test_run_summarization(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_summarization.py
            --model_name_or_path t5-small
            --train_file tests/fixtures/tests_samples/xsum/sample.json
            --validation_file tests/fixtures/tests_samples/xsum/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=50
            --warmup_steps=8
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --predict_with_generate
        """.split()

        with patch.object(sys, "argv", testargs):
            run_summarization.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_rouge1"], 10)
            self.assertGreaterEqual(result["eval_rouge2"], 2)
            self.assertGreaterEqual(result["eval_rougeL"], 7)
            self.assertGreaterEqual(result["eval_rougeLsum"], 7)

    @slow
    def test_run_translation(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        # NOTE(review): --source_lang/--target_lang appear twice; the later
        # en_XX/ro_RO pair wins during argument parsing. Kept as-is to preserve
        # the exact arguments the original test passed.
        testargs = f"""
            run_translation.py
            --model_name_or_path sshleifer/student_marian_en_ro_6_1
            --source_lang en
            --target_lang ro
            --train_file tests/fixtures/tests_samples/wmt16/sample.json
            --validation_file tests/fixtures/tests_samples/wmt16/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=50
            --warmup_steps=8
            --do_train
            --do_eval
            --learning_rate=3e-3
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --predict_with_generate
            --source_lang en_XX
            --target_lang ro_RO
        """.split()

        with patch.object(sys, "argv", testargs):
            run_translation.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_bleu"], 30)
347
348
349
350
351
352
353
354
355
356

    def test_run_image_classification(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_image_classification.py
            --output_dir {tmp_dir}
            --model_name_or_path google/vit-base-patch16-224-in21k
            --dataset_name hf-internal-testing/cats_vs_dogs_sample
            --do_train
            --do_eval
            --learning_rate 1e-4
            --per_device_train_batch_size 2
            --per_device_eval_batch_size 1
            --remove_unused_columns False
            --overwrite_output_dir True
            --dataloader_num_workers 16
            --metric_for_best_model accuracy
            --max_steps 10
            --train_val_split 0.1
            --seed 42
        """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_image_classification.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412

    def test_run_speech_recognition_ctc(self):
        # Mirror the script's logging on stdout so pytest captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Fix: remove the handler when the test ends; otherwise every test run
        # leaks a handler on the root logger and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_speech_recognition_ctc.py
            --output_dir {tmp_dir}
            --model_name_or_path hf-internal-testing/tiny-random-wav2vec2
            --dataset_name patrickvonplaten/librispeech_asr_dummy
            --dataset_config_name clean
            --train_split_name validation
            --eval_split_name validation
            --audio_column_name file
            --do_train
            --do_eval
            --learning_rate 1e-4
            --per_device_train_batch_size 2
            --per_device_eval_batch_size 1
            --remove_unused_columns False
            --overwrite_output_dir True
            --preprocessing_num_workers 16
            --max_steps 10
            --seed 42
        """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_speech_recognition_ctc.main()
            result = get_results(tmp_dir)
            self.assertLess(result["eval_loss"], result["train_loss"])