# test_finetune_trainer.py (2.75 KB)
# Committed by Suraj Patil.
import os
import sys
import tempfile
from unittest.mock import patch

from transformers.testing_utils import slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY


# Fix the random seed once, at import time, so runs of this module are reproducible.
set_seed(42)

# Hub checkpoint used by the slow end-to-end test; run_trainer configures it for an
# en -> ro translation task (judging by the name, a small distilled Marian model).
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


def test_finetune_trainer():
    """Smoke-test a one-epoch finetune of the tiny mBART model.

    Runs the trainer with evaluation every step and checks that the first
    evaluation entry in ``trainer_state.json`` reports an ``eval_bleu`` metric.
    """
    output_dir = run_trainer(eval_steps=1, max_len="12", model_name=MBART_TINY, num_train_epochs=1)
    logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history

    # Evaluation passes are the log entries that carry an "eval_loss" key.
    eval_metrics = [log for log in logs if "eval_loss" in log]
    # Fail with a readable message rather than an IndexError if nothing was evaluated.
    assert eval_metrics, "no evaluation entries were logged in trainer_state.json"

    first_step_stats = eval_metrics[0]
    assert "eval_bleu" in first_step_stats


@slow
def test_finetune_trainer_slow():
    """Finetune the Marian student for three epochs and verify training and prediction outputs.

    Checks that eval BLEU is a float and increases between the first and last
    evaluation, and that ``--do_predict`` wrote its generations and metrics
    files into the output directory.
    """
    # NOTE(review): an earlier comment here reported "a missing call to
    # __init__process_group somewhere"; it is unclear whether that still applies — confirm.
    output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)

    # Check metrics
    logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
    # Evaluation passes are the log entries that carry an "eval_loss" key.
    eval_metrics = [log for log in logs if "eval_loss" in log]
    assert eval_metrics, "no evaluation entries were logged in trainer_state.json"
    first_step_stats = eval_metrics[0]
    last_step_stats = eval_metrics[-1]

    # Check the type first so a non-numeric BLEU fails clearly instead of in the comparison.
    assert isinstance(last_step_stats["eval_bleu"], float)
    assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"], "eval BLEU did not improve: model learned nothing"

    # Test that do_predict saves generations and metrics.
    # os.listdir already returns bare entry names, so no basename() is needed.
    contents = set(os.listdir(output_dir))
    assert "test_generations.txt" in contents
    assert "test_results.json" in contents


def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
    """Run ``finetune_trainer.main`` on the wmt_en_ro test data and return its output directory.

    ``main`` is invoked in-process with ``sys.argv`` patched to a translation
    finetuning command line: train/eval/predict on 8 examples, generation-based
    eval, label smoothing, Adafactor, sortish sampling, en_XX -> ro_RO.

    Args:
        eval_steps: Interval (in steps) for both evaluation and checkpoint saving.
        max_len: Max source, target and validation-target length, as passed on the CLI.
        model_name: Model identifier passed to ``--model_name_or_path``.
        num_train_epochs: Number of training epochs.

    Returns:
        Path of a fresh temporary directory holding the trainer outputs. It is not
        deleted here, so callers can inspect it afterwards.
    """
    # Relative path, resolved against the current working directory; this presumably
    # expects the tests to be launched from the repository root — TODO confirm.
    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
    output_dir = tempfile.mkdtemp(prefix="test_output")

    # Build argv as an explicit list (not a whitespace-split string) so that paths or
    # model names containing spaces stay single arguments.
    argv = [
        "--model_name_or_path", model_name,
        "--data_dir", data_dir,
        "--output_dir", output_dir,
        "--overwrite_output_dir",
        "--n_train", "8",
        "--n_val", "8",
        "--max_source_length", str(max_len),
        "--max_target_length", str(max_len),
        "--val_max_target_length", str(max_len),
        "--do_train",
        "--do_eval",
        "--do_predict",
        "--num_train_epochs", str(num_train_epochs),
        "--per_device_train_batch_size", "4",
        "--per_device_eval_batch_size", "4",
        "--learning_rate", "3e-4",
        "--warmup_steps", "8",
        "--evaluate_during_training",
        "--predict_with_generate",
        "--logging_steps", "0",
        "--save_steps", str(eval_steps),
        "--eval_steps", str(eval_steps),
        "--sortish_sampler",
        "--label_smoothing", "0.1",
        "--adafactor",
        "--task", "translation",
        "--tgt_lang", "ro_RO",
        "--src_lang", "en_XX",
    ]
    # NOTE: "--eval_beams 2" was previously left commented out here; beam search
    # during eval is intentionally not enabled.

    # argv[0] mimics running the script directly.
    testargs = ["finetune_trainer.py"] + argv
    with patch.object(sys, "argv", testargs):
        main()

    return output_dir