"""Continue training a SentenceTransformer model on the STS benchmark.

Downloads the STSbenchmark TSV archive if it is missing, fine-tunes a
pre-trained sentence-transformers model with CosineSimilarityLoss, and
evaluates the stored model on the STS test split.
"""
import argparse
import csv
import gzip
import logging
import math
import os
from datetime import datetime

from torch.utils.data import DataLoader

from sentence_transformers import (
    InputExample,
    LoggingHandler,
    SentencesDataset,
    SentenceTransformer,
    losses,
    util,
)
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

#### Just some code to print debug information to stdout
logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)

# Command-line configuration for the fine-tuning run.
parser = argparse.ArgumentParser(description="Fine-tune a SentenceTransformer model on STSbenchmark.")
parser.add_argument('--data_path', type=str, default='datasets/stsbenchmark.tsv.gz', help='Input txt path')
parser.add_argument('--train_batch_size', type=int, default=16, help='Mini-batch size used during training')
parser.add_argument('--num_epochs', type=int, default=10, help='Number of fine-tuning epochs')
parser.add_argument('--model_name_or_path', type=str, default="all-MiniLM-L6-v2",
                    help='Pre-trained model name (or local path) to continue training from')
parser.add_argument('--save_root_path', type=str, default="output", help='Model output folder')
# BUG FIX: without type=float, a value given on the command line stays a str,
# which breaks the optimizer configuration passed to model.fit downstream.
parser.add_argument('--lr', type=float, default=2e-05, help='Learning rate for fine-tuning')
args = parser.parse_args()


if __name__ == "__main__":

    sts_dataset_path = args.data_path

    # Check if the dataset exists. If not, download it.
    if not os.path.exists(sts_dataset_path):
        # Make sure the target directory exists before downloading.
        dataset_dir = os.path.dirname(sts_dataset_path)
        if dataset_dir:
            os.makedirs(dataset_dir, exist_ok=True)
        util.http_get(
            'https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/stsbenchmark.tsv.gz',
            sts_dataset_path,
        )

    model_name = args.model_name_or_path
    train_batch_size = args.train_batch_size
    num_epochs = args.num_epochs
    model_save_path = os.path.join(
        args.save_root_path,
        "training_stsbenchmark_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
    )

    # Load a pre-trained sentence transformer model.
    # BUG FIX: do not hard-code device='cuda' -- SentenceTransformer selects
    # CUDA automatically when it is available, and this way the script still
    # runs on CPU-only machines instead of crashing.
    model = SentenceTransformer(model_name)

    # Read the STSbenchmark TSV and split it into train/dev/test examples.
    logging.info("Read STSbenchmark train dataset")
    train_samples = []
    dev_samples = []
    test_samples = []
    with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
        reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
        for row in reader:
            # Gold scores are 0..5; normalize to 0..1 for CosineSimilarityLoss.
            score = float(row['score']) / 5.0
            inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)

            if row['split'] == 'dev':
                dev_samples.append(inp_example)
            elif row['split'] == 'test':
                test_samples.append(inp_example)
            else:
                train_samples.append(inp_example)

    logging.info("Dealing data end.")

    # NOTE: SentencesDataset is deprecated -- lists of InputExample can be fed
    # to a DataLoader directly; model.fit installs its smart-batching collate.
    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
    train_loss = losses.CosineSimilarityLoss(model=model)

    # Development set: measure correlation between cosine score and gold labels.
    logging.info("Read STSbenchmark dev dataset")
    evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")

    # 10% of the training steps are used for learning-rate warm-up.
    warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)
    logging.info("Warmup-steps: {}".format(warmup_steps))

    print("Start training ...")
    # Train the model.
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=evaluator,
        epochs=num_epochs,
        evaluation_steps=1000,
        warmup_steps=warmup_steps,
        # float() guards against a string lr coming from the command line.
        optimizer_params={'lr': float(args.lr)},
        output_path=model_save_path,
    )
    logging.info("Finetune end")

    ##############################################################################
    #
    # Load the stored model and evaluate its performance on STS benchmark dataset
    #
    ##############################################################################

    model = SentenceTransformer(model_save_path)
    test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
    test_evaluator(model, output_path=model_save_path)