test_trainer.py 5.56 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import re
import tempfile
from pathlib import Path

import pytest

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from sentence_transformers.util import is_datasets_available, is_training_available

if is_datasets_available():
    from datasets import DatasetDict


@pytest.mark.skipif(
    not is_training_available(),
    reason='Sentence Transformers was not installed with the `["train"]` extra.',
)
def test_trainer_multi_dataset_errors(
    stsb_bert_tiny_model: SentenceTransformer, stsb_dataset_dict: "DatasetDict"
) -> None:
    """Check that a dict ``loss`` is validated against the train/eval datasets.

    Four misconfigurations are exercised; each must raise ``ValueError`` when
    the ``SentenceTransformerTrainer`` is constructed:

    1. ``loss`` is a dict but ``train_dataset`` is a single dataset.
    2. ``train_dataset`` contains a key that is absent from ``loss``.
    3. ``loss`` is a dict but ``eval_dataset`` is a single dataset.
    4. ``eval_dataset`` contains several keys that are absent from ``loss``.

    Expected messages are wrapped in ``re.escape`` because ``pytest.raises``
    treats ``match`` as a regular expression: the brackets, quotes and periods
    in the messages must be matched literally, and the plain string literals
    avoid invalid ``\\[`` escape sequences.

    Args:
        stsb_bert_tiny_model: Fixture providing the model under test.
        stsb_dataset_dict: Fixture providing ``"train"`` and ``"validation"`` splits.
    """
    train_dataset = stsb_dataset_dict["train"]
    loss = {
        "multi_nli": losses.CosineSimilarityLoss(model=stsb_bert_tiny_model),
        "snli": losses.CosineSimilarityLoss(model=stsb_bert_tiny_model),
        "stsb": losses.CosineSimilarityLoss(model=stsb_bert_tiny_model),
    }

    # Case 1: dict loss combined with a single (non-dict) training dataset.
    with pytest.raises(
        ValueError,
        match=re.escape("If the provided `loss` is a dict, then the `train_dataset` must be a `DatasetDict`."),
    ):
        SentenceTransformerTrainer(model=stsb_bert_tiny_model, train_dataset=train_dataset, loss=loss)

    # Case 2: the training DatasetDict has one key ("stsb-extra") without a loss.
    train_dataset = DatasetDict(
        {
            "multi_nli": stsb_dataset_dict["train"],
            "snli": stsb_dataset_dict["train"],
            "stsb": stsb_dataset_dict["train"],
            "stsb-extra": stsb_dataset_dict["train"],
        }
    )
    with pytest.raises(
        ValueError,
        match=re.escape(
            "If the provided `loss` is a dict, then all keys from the `train_dataset` dictionary must occur in `loss` also. "
            "Currently, ['stsb-extra'] occurs in `train_dataset` but not in `loss`."
        ),
    ):
        SentenceTransformerTrainer(model=stsb_bert_tiny_model, train_dataset=train_dataset, loss=loss)

    # Case 3: valid training DatasetDict, but a single (non-dict) evaluation dataset.
    train_dataset = DatasetDict(
        {
            "multi_nli": stsb_dataset_dict["train"],
            "snli": stsb_dataset_dict["train"],
            "stsb": stsb_dataset_dict["train"],
        }
    )
    with pytest.raises(
        ValueError,
        match=re.escape("If the provided `loss` is a dict, then the `eval_dataset` must be a `DatasetDict`."),
    ):
        SentenceTransformerTrainer(
            model=stsb_bert_tiny_model,
            train_dataset=train_dataset,
            eval_dataset=stsb_dataset_dict["validation"],
            loss=loss,
        )

    # Case 4: the evaluation DatasetDict has two keys without a loss; the
    # message switches to the plural "occur".
    eval_dataset = DatasetDict(
        {
            "multi_nli": stsb_dataset_dict["validation"],
            "snli": stsb_dataset_dict["validation"],
            "stsb": stsb_dataset_dict["validation"],
            "stsb-extra-1": stsb_dataset_dict["validation"],
            "stsb-extra-2": stsb_dataset_dict["validation"],
        }
    )
    with pytest.raises(
        ValueError,
        match=re.escape(
            "If the provided `loss` is a dict, then all keys from the `eval_dataset` dictionary must occur in `loss` also. "
            "Currently, ['stsb-extra-1', 'stsb-extra-2'] occur in `eval_dataset` but not in `loss`."
        ),
    ):
        SentenceTransformerTrainer(
            model=stsb_bert_tiny_model, train_dataset=train_dataset, eval_dataset=eval_dataset, loss=loss
        )


@pytest.mark.skipif(
    not is_training_available(),
    reason='Sentence Transformers was not installed with the `["train"]` extra.',
)
def test_trainer_invalid_column_names(
    stsb_bert_tiny_model: SentenceTransformer, stsb_dataset_dict: "DatasetDict"
) -> None:
    """Verify that ``train()`` rejects datasets that use reserved column names.

    For each of the names ``"return_loss"`` and ``"dataset_name"``, the
    ``"sentence1"`` column of the training split is renamed to that name and
    training is attempted twice: once with the dataset passed directly, and
    once with it stored under the ``"stsb"`` key of a ``DatasetDict``. Both
    attempts must raise ``ValueError``; in the second, the message names the
    offending dataset.

    Args:
        stsb_bert_tiny_model: Fixture providing the model under test.
        stsb_dataset_dict: Fixture providing the ``"train"`` split.
    """
    stsb_train = stsb_dataset_dict["train"]
    for column_name in ("return_loss", "dataset_name"):
        # Single-dataset scenario: construction succeeds, the error only
        # surfaces once training starts.
        single_trainer = SentenceTransformerTrainer(
            model=stsb_bert_tiny_model,
            train_dataset=stsb_train.rename_column("sentence1", column_name),
        )
        single_pattern = re.escape(
            f"The following column names are invalid in your dataset: ['{column_name}']."
            " Avoid using these column names, as they are reserved for internal use."
        )
        with pytest.raises(ValueError, match=single_pattern):
            single_trainer.train()

        # Multi-dataset scenario: only the "stsb" entry carries the bad column,
        # "stsb-2" is the untouched split.
        named_datasets = DatasetDict(
            {
                "stsb": stsb_train.rename_column("sentence1", column_name),
                "stsb-2": stsb_train,
            }
        )
        multi_trainer = SentenceTransformerTrainer(model=stsb_bert_tiny_model, train_dataset=named_datasets)
        multi_pattern = re.escape(
            f"The following column names are invalid in your stsb dataset: ['{column_name}']."
            " Avoid using these column names, as they are reserved for internal use."
        )
        with pytest.raises(ValueError, match=multi_pattern):
            multi_trainer.train()


def _save_and_read_model_card(model: SentenceTransformer) -> str:
    """Save ``model`` into a temporary directory and return its README.md text."""
    with tempfile.TemporaryDirectory() as tmp_folder:
        model_path = Path(tmp_folder) / "tiny_model_local"
        model.save(str(model_path))
        # Decode explicitly rather than with the platform default encoding, so
        # the comparison does not depend on the machine's locale.
        # NOTE(review): assumes `save` writes README.md as UTF-8 - confirm.
        return (model_path / "README.md").read_text(encoding="utf8")


@pytest.mark.skipif(
    not is_training_available(),
    reason='Sentence Transformers was not installed with the `["train"]` extra.',
)
def test_model_card_reuse(stsb_bert_tiny_model: SentenceTransformer) -> None:
    """Check when a saved model reuses its existing model card text.

    Before any trainer exists, saving the model must write a README.md equal
    to ``_model_card_text``. After a ``SentenceTransformerTrainer`` has been
    constructed for the model, saving must write a README.md that differs
    from ``_model_card_text``.

    Args:
        stsb_bert_tiny_model: Fixture providing a model with a non-empty
            ``_model_card_text``.
    """
    assert stsb_bert_tiny_model._model_card_text

    # Reuse the model card if no training was done
    model_card_text = _save_and_read_model_card(stsb_bert_tiny_model)
    assert model_card_text == stsb_bert_tiny_model._model_card_text

    # Create a new model card if a Trainer was initialized
    SentenceTransformerTrainer(model=stsb_bert_tiny_model)

    model_card_text = _save_and_read_model_card(stsb_bert_tiny_model)
    assert model_card_text != stsb_bert_tiny_model._model_card_text