train.py 2.66 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import json
from pathlib import Path

from torch.utils.data import DataLoader
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
)

from lettucedetect.datasets.ragtruth import RagTruthDataset
from lettucedetect.models.trainer import Trainer
from lettucedetect.preprocess.preprocess_ragtruth import RagTruthData


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    :param argv: Optional sequence of argument strings to parse. When ``None``
        (the default) argparse falls back to ``sys.argv[1:]``, so existing
        callers that invoke ``parse_args()`` behave exactly as before. Passing
        an explicit list makes the function usable from tests or other code.
    :return: Namespace exposing ``data_path``, ``model_name``, ``output_dir``,
        ``batch_size``, ``epochs`` and ``learning_rate``.
    """
    parser = argparse.ArgumentParser(description="Train hallucination detector model")
    parser.add_argument(
        "--data-path",
        type=str,
        default="data/ragtruth/ragtruth_data.json",
        help="Path to the training data JSON file",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="answerdotai/ModernBERT-base",
        help="Name or path of the pretrained model",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="output/hallucination_detector",
        help="Directory to save the trained model",
    )
    parser.add_argument(
        "--batch-size", type=int, default=4, help="Batch size for training and testing"
    )
    parser.add_argument("--epochs", type=int, default=6, help="Number of training epochs")
    parser.add_argument(
        "--learning-rate", type=float, default=1e-5, help="Learning rate for training"
    )
    # argv=None makes argparse read sys.argv[1:], preserving the CLI behaviour.
    return parser.parse_args(argv)


def main():
    """Train a token-classification hallucination detector on RAGTruth data.

    Reads the options from the command line, loads the JSON data file, builds
    train/test data loaders from the samples whose ``split`` is ``"train"`` or
    ``"test"``, loads the pretrained model with a two-label token
    classification head, and hands everything to ``Trainer.train()``.
    """
    args = parse_args()
    data_path = Path(args.data_path)
    # Decode explicitly as UTF-8: without an encoding, read_text() uses the
    # platform locale, which can fail or corrupt non-ASCII text (e.g. Windows).
    raw_data = json.loads(data_path.read_text(encoding="utf-8"))
    rag_truth_data = RagTruthData.from_json(raw_data)

    train_samples = [sample for sample in rag_truth_data.samples if sample.split == "train"]
    test_samples = [sample for sample in rag_truth_data.samples if sample.split == "test"]

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Labels are padded with -100 when batching.
    # NOTE(review): presumably the loss ignores -100 positions; confirm in Trainer.
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, label_pad_token_id=-100)

    train_dataset = RagTruthDataset(train_samples, tokenizer)
    test_dataset = RagTruthDataset(test_samples, tokenizer)

    # Only the training loader shuffles; the test loader keeps dataset order.
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=data_collator,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=data_collator,
    )

    model = AutoModelForTokenClassification.from_pretrained(args.model_name, num_labels=2)

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_loader=train_loader,
        test_loader=test_loader,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
        save_path=args.output_dir,
    )

    trainer.train()


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()