test_examples.py 12.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16
17

import argparse
18
import json
19
import logging
20
import os
Aymeric Augustin's avatar
Aymeric Augustin committed
21
import sys
Aymeric Augustin's avatar
Aymeric Augustin committed
22
from unittest.mock import patch
Aymeric Augustin's avatar
Aymeric Augustin committed
23

Stas Bekman's avatar
Stas Bekman committed
24
25
import torch

26
from transformers.file_utils import is_apex_available
27
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
28

29
30
31

# Each example task lives in its own sibling folder; put every folder on
# sys.path so the run_* example scripts can be imported as plain modules.
SRC_DIRS = [
    os.path.join(os.path.dirname(__file__), task_dir)
    for task_dir in (
        "text-generation",
        "text-classification",
        "token-classification",
        "language-modeling",
        "multiple-choice",
        "question-answering",
        "summarization",
        "translation",
        "image-classification",
    )
]
sys.path.extend(SRC_DIRS)


if SRC_DIRS is not None:
    import run_clm
    import run_generation
    import run_glue
    import run_image_classification
    import run_mlm
    import run_ner
    import run_qa as run_squad
    import run_summarization
    import run_swag
    import run_translation
Aymeric Augustin's avatar
Aymeric Augustin committed
58

59

60
61
62
# Emit full debug output from the example scripts while the tests run.
logging.basicConfig(level=logging.DEBUG)

# Root logger: the tests attach stdout handlers to it so pytest captures
# whatever the example scripts log.
logger = logging.getLogger()
63

64

65
66
def get_setup_file():
    parser = argparse.ArgumentParser()
67
    parser.add_argument("-f")
68
69
70
71
    args = parser.parse_args()
    return args.f


72
73
74
75
76
77
78
79
80
81
82
def get_results(output_dir):
    """Load the metrics an example script wrote to ``all_results.json``.

    Args:
        output_dir: directory the example script used as ``--output_dir``.

    Returns:
        dict: the parsed metrics.

    Raises:
        ValueError: if ``all_results.json`` does not exist in ``output_dir``.
    """
    path = os.path.join(output_dir, "all_results.json")
    # Guard clause keeps the happy path unindented.
    if not os.path.exists(path):
        raise ValueError(f"can't find {path}")
    # Decode explicitly as UTF-8 instead of the platform default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


83
def is_cuda_and_apex_available():
    """Return True when the test run targets CUDA and NVIDIA apex is installed."""
    if not (torch.cuda.is_available() and torch_device == "cuda"):
        return False
    return is_apex_available()


88
class ExamplesTests(TestCasePlus):
    """End-to-end smoke tests for the example scripts on tiny fixture datasets.

    Each test assembles a command line, patches ``sys.argv``, calls the
    example script's ``main()``, then asserts on the metrics the script saved
    to ``all_results.json`` in an auto-removed temporary output directory.
    """

    def _capture_example_logs(self):
        """Mirror root-logger output to stdout for the duration of one test.

        Fixes a handler leak: previously every test added a fresh
        ``StreamHandler`` to the root logger and never removed it, so each
        later test emitted every log line once per previously-run test.
        ``addCleanup`` detaches the handler even if the test fails.
        """
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        self.addCleanup(logger.removeHandler, stream_handler)

    def test_run_glue(self):
        self._capture_example_logs()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_glue.py
            --model_name_or_path distilbert-base-uncased
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
            --do_train
            --do_eval
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --learning_rate=1e-4
            --max_steps=10
            --warmup_steps=2
            --seed=42
            --max_seq_length=128
            """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_glue.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)

    def test_run_clm(self):
        self._capture_example_logs()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_clm.py
            --model_name_or_path distilgpt2
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --do_train
            --do_eval
            --block_size 128
            --per_device_train_batch_size 5
            --per_device_eval_batch_size 5
            --num_train_epochs 2
            --output_dir {tmp_dir}
            --overwrite_output_dir
            """.split()

        if torch.cuda.device_count() > 1:
            # Skipping because there are not enough batches to train the model + would need a drop_last to work.
            return

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_clm.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 100)

    def test_run_mlm(self):
        self._capture_example_logs()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_mlm.py
            --model_name_or_path distilroberta-base
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --prediction_loss_only
            --num_train_epochs=1
        """.split()

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_mlm.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 42)

    def test_run_ner(self):
        self._capture_example_logs()

        # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
        epochs = 7 if get_gpu_count() > 1 else 2

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_ner.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/conll/sample.json
            --validation_file tests/fixtures/tests_samples/conll/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --warmup_steps=2
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=2
            --num_train_epochs={epochs}
            --seed 7
        """.split()

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            run_ner.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
            self.assertLess(result["eval_loss"], 0.5)

    def test_run_squad(self):
        self._capture_example_logs()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_qa.py
            --model_name_or_path bert-base-uncased
            --version_2_with_negative
            --train_file tests/fixtures/tests_samples/SQUAD/sample.json
            --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=10
            --warmup_steps=2
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
        """.split()

        with patch.object(sys, "argv", testargs):
            run_squad.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_f1"], 30)
            self.assertGreaterEqual(result["eval_exact"], 30)

    def test_run_swag(self):
        self._capture_example_logs()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_swag.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/swag/sample.json
            --validation_file tests/fixtures/tests_samples/swag/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=20
            --warmup_steps=2
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
        """.split()

        with patch.object(sys, "argv", testargs):
            run_swag.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)

    def test_generation(self):
        self._capture_example_logs()

        testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        model_type, model_name = (
            "--model_type=gpt2",
            "--model_name_or_path=sshleifer/tiny-gpt2",
        )
        with patch.object(sys, "argv", testargs + [model_type, model_name]):
            # run_generation returns the generated sequences directly rather
            # than writing all_results.json.
            result = run_generation.main()
            self.assertGreaterEqual(len(result[0]), 10)

    @slow
    def test_run_summarization(self):
        self._capture_example_logs()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_summarization.py
            --model_name_or_path t5-small
            --train_file tests/fixtures/tests_samples/xsum/sample.json
            --validation_file tests/fixtures/tests_samples/xsum/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=50
            --warmup_steps=8
            --do_train
            --do_eval
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --predict_with_generate
        """.split()

        with patch.object(sys, "argv", testargs):
            run_summarization.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_rouge1"], 10)
            self.assertGreaterEqual(result["eval_rouge2"], 2)
            self.assertGreaterEqual(result["eval_rougeL"], 7)
            self.assertGreaterEqual(result["eval_rougeLsum"], 7)

    @slow
    def test_run_translation(self):
        self._capture_example_logs()

        tmp_dir = self.get_auto_remove_tmp_dir()
        # NOTE(review): --source_lang/--target_lang appear twice below; argparse
        # keeps the last occurrence, so en_XX/ro_RO win. Preserved as-is.
        testargs = f"""
            run_translation.py
            --model_name_or_path sshleifer/student_marian_en_ro_6_1
            --source_lang en
            --target_lang ro
            --train_file tests/fixtures/tests_samples/wmt16/sample.json
            --validation_file tests/fixtures/tests_samples/wmt16/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=50
            --warmup_steps=8
            --do_train
            --do_eval
            --learning_rate=3e-3
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --predict_with_generate
            --source_lang en_XX
            --target_lang ro_RO
        """.split()

        with patch.object(sys, "argv", testargs):
            run_translation.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_bleu"], 30)

    def test_run_image_classification(self):
        self._capture_example_logs()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_image_classification.py
            --output_dir {tmp_dir}
            --model_name_or_path google/vit-base-patch16-224-in21k
            --train_dir tests/fixtures/tests_samples/cats_and_dogs/
            --do_train
            --do_eval
            --learning_rate 2e-5
            --per_device_train_batch_size 2
            --per_device_eval_batch_size 1
            --remove_unused_columns False
            --overwrite_output_dir True
            --dataloader_num_workers 16
            --metric_for_best_model accuracy
            --max_steps 30
            --train_val_split 0.1
            --seed 7
        """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_image_classification.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)