import os
import sys
import tempfile
from unittest.mock import patch

from transformers.testing_utils import slow
from transformers.trainer_utils import set_seed

from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY
from .utils import load_json


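# Seed all RNGs (python, numpy, torch) so the tiny training runs below are reproducible.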
set_seed(42)
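# A small distilled Marian en->ro student checkpoint (the "6_1" naming suggests 6 encoder / 1 decoder layers).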
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


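# Fast smoke test: one short epoch of MBART_TINY on the toy dataset; only asserts that BLEU shows up in the eval logs.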
def test_finetune_trainer():
    output_dir = run_trainer(1, "12", MBART_TINY, 1)
    logs = load_json(os.path.join(output_dir, "log_history.json"))
    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
    first_step_stats = eval_metrics[0]
    assert "eval_bleu" in first_step_stats


@slow
def test_finetune_trainer_slow():
    # TODO(SS): This will fail on devices with more than 1 GPU.
    # There is a missing call to torch.distributed.init_process_group somewhere.
    output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)

    # Check metrics
    logs = load_json(os.path.join(output_dir, "log_history.json"))
    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
    first_step_stats = eval_metrics[0]
    last_step_stats = eval_metrics[-1]

    assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
    assert isinstance(last_step_stats["eval_bleu"], float)

    # test if do_predict saves generations and metrics
    contents = os.listdir(output_dir)
    contents = {os.path.basename(p) for p in contents}
    assert "test_generations.txt" in contents
    assert "test_results.json" in contents


def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
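    # Tiny English-Romanian parallel data shipped with the seq2seq examples for testing.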
    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
    output_dir = tempfile.mkdtemp(prefix="test_output")
    argv = [
        "--model_name_or_path",
        model_name,
        "--data_dir",
        data_dir,
        "--output_dir",
        output_dir,
        "--overwrite_output_dir",
        "--n_train",
        "8",
        "--n_val",
        "8",
        "--max_source_length",
        max_len,
        "--max_target_length",
        max_len,
        "--val_max_target_length",
        max_len,
        "--do_train",
        "--do_eval",
        "--do_predict",
        "--num_train_epochs",
        str(num_train_epochs),
        "--per_device_train_batch_size",
        "4",
        "--per_device_eval_batch_size",
        "4",
        "--learning_rate",
        "3e-4",
        "--warmup_steps",
        "8",
        "--evaluate_during_training",
        "--predict_with_generate",
        "--logging_steps",
        "0",  # all argv values must be strings
        "--save_steps",
        str(eval_steps),
        "--eval_steps",
        str(eval_steps),
        "--sortish_sampler",
        "--label_smoothing",
        "0.1",
        # "--eval_beams",
        # "2",
        "--task",
        "translation",
        "--tgt_lang",
        "ro_RO",
        "--src_lang",
        "en_XX",
    ]
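    # finetune_trainer.main() parses its arguments from sys.argv, so emulate a
    # CLI invocation by patching argv for the duration of the call.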
    testargs = ["finetune_trainer.py"] + argv
    with patch.object(sys, "argv", testargs):
        main()

    return output_dir