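"""Integration tests for finetune_trainer.py: a fast smoke test with a tiny
MBart checkpoint and a slower end-to-end run with a distilled Marian model."""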
import os
import sys
import tempfile
from unittest.mock import patch

from transformers.testing_utils import slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY


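# Seed all RNGs up front so the tiny training runs below are reproducible.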
set_seed(42)
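# Distilled Marian en->ro student checkpoint (per its name, 6 encoder / 1 decoder layers); used by the slow test.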
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


def test_finetune_trainer():
    output_dir = run_trainer(eval_steps=1, max_len="12", model_name=MBART_TINY, num_train_epochs=1)
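    # The Trainer dumps its full log history, including eval metrics, to trainer_state.json.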
    logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
    first_step_stats = eval_metrics[0]
    assert "eval_bleu" in first_step_stats


@slow
def test_finetune_trainer_slow():
    # TODO(SS): This will fail on devices with more than 1 GPU.
    # There is a missing call to torch.distributed.init_process_group somewhere.
    output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)

    # Check metrics
    logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
    first_step_stats = eval_metrics[0]
    last_step_stats = eval_metrics[-1]

    assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # BLEU should improve over training
    assert isinstance(last_step_stats["eval_bleu"], float)

    # test if do_predict saves generations and metrics
    contents = {os.path.basename(p) for p in os.listdir(output_dir)}
    assert "test_generations.txt" in contents
    assert "test_results.json" in contents


def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
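    """Run finetune_trainer.main on the tiny WMT en-ro test data via a patched sys.argv and return the output dir."""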
    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
    output_dir = tempfile.mkdtemp(prefix="test_output")
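    # Flags mirror a real finetune_trainer.py run, scaled down (8 train/val examples, small batches) for speed.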
    argv = [
        "--model_name_or_path",
        model_name,
        "--data_dir",
        data_dir,
        "--output_dir",
        output_dir,
        "--overwrite_output_dir",
        "--n_train",
        "8",
        "--n_val",
        "8",
        "--max_source_length",
        max_len,
        "--max_target_length",
        max_len,
        "--val_max_target_length",
        max_len,
        "--do_train",
        "--do_eval",
        "--do_predict",
        "--num_train_epochs",
        str(num_train_epochs),
        "--per_device_train_batch_size",
        "4",
        "--per_device_eval_batch_size",
        "4",
        "--learning_rate",
        "3e-4",
        "--warmup_steps",
        "8",
        "--evaluate_during_training",
        "--predict_with_generate",
        "--logging_steps",
        "0",
        "--save_steps",
        str(eval_steps),
        "--eval_steps",
        str(eval_steps),
        "--sortish_sampler",
        "--label_smoothing",
        "0.1",
        # "--eval_beams",
        # "2",
        "--adafactor",
        "--task",
        "translation",
        "--tgt_lang",
        "ro_RO",
        "--src_lang",
        "en_XX",
    ]
    testargs = ["finetune_trainer.py"] + argv
    with patch.object(sys, "argv", testargs):
        main()

    return output_dir