import os
import sys
import tempfile
from unittest.mock import patch

from transformers.testing_utils import slow
from transformers.trainer_utils import TrainerState, set_seed

from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY


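# Fix the RNG seeds up front so both tests produce deterministic runs.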
set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


def test_finetune_trainer():
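    # Quick check: one epoch of MBART_TINY with max length 12, evaluating every step.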
    output_dir = run_trainer(1, "12", MBART_TINY, 1)
    logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
    eval_metrics = [log for log in logs if "eval_loss" in log]
    first_step_stats = eval_metrics[0]
    assert "eval_bleu" in first_step_stats


@slow
def test_finetune_trainer_slow():
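    # Marked @slow: transformers.testing_utils.slow skips this test unless the
    # RUN_SLOW environment variable is set.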
    # TODO(SS): This will fail on devices with more than 1 GPU.
    # There is a missing call to torch.distributed.init_process_group somewhere.
    output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)

    # Check metrics
    logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
    eval_metrics = [log for log in logs if "eval_loss" in log]
    first_step_stats = eval_metrics[0]
    last_step_stats = eval_metrics[-1]

    assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # BLEU should improve; otherwise the model learned nothing
    assert isinstance(last_step_stats["eval_bleu"], float)

    # test if do_predict saves generations and metrics
    contents = set(os.listdir(output_dir))  # os.listdir already returns bare filenames
    assert "test_generations.txt" in contents
    assert "test_results.json" in contents


def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
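    """Drive finetune_trainer.main end to end by patching sys.argv with a
    CLI-style argument list, and return the temporary output directory."""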
    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
    output_dir = tempfile.mkdtemp(prefix="test_output")
    argv = [
        "--model_name_or_path",
        model_name,
        "--data_dir",
        data_dir,
        "--output_dir",
        output_dir,
        "--overwrite_output_dir",
        "--n_train",
        "8",
        "--n_val",
        "8",
        "--max_source_length",
        max_len,
        "--max_target_length",
        max_len,
        "--val_max_target_length",
        max_len,
        "--do_train",
        "--do_eval",
        "--do_predict",
        "--num_train_epochs",
        str(num_train_epochs),
        "--per_device_train_batch_size",
        "4",
        "--per_device_eval_batch_size",
        "4",
        "--learning_rate",
        "3e-4",
        "--warmup_steps",
        "8",
        "--evaluate_during_training",
        "--predict_with_generate",
        "--logging_steps",
        0,
        "--save_steps",
        str(eval_steps),
        "--eval_steps",
        str(eval_steps),
        "--sortish_sampler",
        "--label_smoothing",
        "0.1",
        # "--eval_beams",
        # "2",
        "--adafactor",
        "--task",
        "translation",
        "--tgt_lang",
        "ro_RO",
        "--src_lang",
        "en_XX",
    ]
    testargs = ["finetune_trainer.py"] + argv
    with patch.object(sys, "argv", testargs):
        main()

    return output_dir
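

# Minimal usage sketch (assumption: this file lives under examples/seq2seq/, as
# data_dir above suggests; run from the repository root):
#   RUN_SLOW=1 pytest examples/seq2seq/test_finetune_trainer.py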