import torch
from sklearn.metrics import classification_report, precision_recall_fscore_support
from torch.nn import Module
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from lettucedetect.datasets.ragtruth import RagTruthSample
from lettucedetect.models.inference import HallucinationDetector


def evaluate_model(
    model: Module,
    dataloader: DataLoader,
    device: torch.device,
    verbose: bool = True,
) -> dict[str, dict[str, float]]:
    """Evaluate a model for hallucination detection.

    :param model: The model to evaluate.
    :param dataloader: The data loader to use for evaluation.
    :param device: The device to use for evaluation.
    :param verbose: If True, print the evaluation metrics.
    :return: A dictionary containing the evaluation metrics.
        {
            "supported": {"precision": float, "recall": float, "f1": float},
            "hallucinated": {"precision": float, "recall": float, "f1": float}
        }
        If verbose is True, the sklearn classification report string is also
        included under the "classification_report" key.
    """
    model.eval()
    all_preds: list[int] = []
    all_labels: list[int] = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            outputs = model(
                batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            logits: torch.Tensor = outputs.logits
            predictions = torch.argmax(logits, dim=-1)

            # Only evaluate on tokens that have labels (not -100)
            mask = batch["labels"] != -100
            predictions = predictions[mask].cpu().numpy()
            labels = batch["labels"][mask].cpu().numpy()

            all_preds.extend(predictions.tolist())
            all_labels.extend(labels.tolist())

    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, labels=[0, 1], average=None
    )

    results: dict[str, dict[str, float]] = {
        "supported": {  # Class 0
            "precision": float(precision[0]),
            "recall": float(recall[0]),
            "f1": float(f1[0]),
        },
        "hallucinated": {  # Class 1
            "precision": float(precision[1]),
            "recall": float(recall[1]),
            "f1": float(f1[1]),
        },
    }

    if verbose:
        report = classification_report(
            all_labels, all_preds, target_names=["Supported", "Hallucinated"], digits=4
        )
        print("\nDetailed Classification Report:")
        print(report)
        results["classification_report"] = report

    return results

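# Illustrative usage sketch (assumptions: a fine-tuned token-classification
# checkpoint at `checkpoint_path` and a DataLoader `eval_loader` whose batches
# contain "input_ids", "attention_mask", and "labels", with ignored positions
# labelled -100 as expected above):
#
#     from transformers import AutoModelForTokenClassification
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = AutoModelForTokenClassification.from_pretrained(checkpoint_path).to(device)
#     token_metrics = evaluate_model(model, eval_loader, device, verbose=False)
#     print_metrics(token_metrics)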

def print_metrics(metrics: dict[str, dict[str, float]]) -> None:
    """Print evaluation metrics in a readable format.

    :param metrics: A dictionary containing the evaluation metrics.
    :return: None
    """
    print("\nEvaluation Results:")
    print("\nHallucination Detection (Class 1):")
    print(f"  Precision: {metrics['hallucinated']['precision']:.4f}")
    print(f"  Recall: {metrics['hallucinated']['recall']:.4f}")
    print(f"  F1: {metrics['hallucinated']['f1']:.4f}")

    print("\nSupported Content (Class 0):")
    print(f"  Precision: {metrics['supported']['precision']:.4f}")
    print(f"  Recall: {metrics['supported']['recall']:.4f}")
    print(f"  F1: {metrics['supported']['f1']:.4f}")

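# Minimal sketch of the dict shape print_metrics expects (the values below are
# placeholders, not real results):
#
#     placeholder_metrics = {
#         "supported": {"precision": 0.0, "recall": 0.0, "f1": 0.0},
#         "hallucinated": {"precision": 0.0, "recall": 0.0, "f1": 0.0},
#     }
#     print_metrics(placeholder_metrics)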

def evaluate_model_example_level(
    model: Module,
    dataloader: DataLoader,
    device: torch.device,
    verbose: bool = True,
) -> dict[str, dict[str, float]]:
    """Evaluate a model for hallucination detection at the example level.

    For each example, if any token is marked as hallucinated (label=1),
    then the whole example is considered hallucinated. Otherwise, it is supported.

    :param model: The model to evaluate.
    :param dataloader: DataLoader providing the evaluation batches.
    :param device: Device on which to perform evaluation.
    :param verbose: If True, prints a detailed classification report.

    :return: A dict containing example-level metrics:
        {
            "supported": {"precision": float, "recall": float, "f1": float},
            "hallucinated": {"precision": float, "recall": float, "f1": float}
        }
        If verbose is True, the sklearn classification report string is also
        included under the "classification_report" key.
    """
    model.eval()
    example_preds: list[int] = []
    example_labels: list[int] = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating (Example Level)", leave=False):
            # Move inputs to device. Note that `batch["labels"]`
            # can stay on CPU if you wish to avoid unnecessary transfers.
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits: torch.Tensor = outputs.logits  # Shape: [batch_size, seq_len, num_labels]
            predictions: torch.Tensor = torch.argmax(logits, dim=-1)  # Shape: [batch_size, seq_len]

            # Process each example in the batch separately.
            for i in range(batch["labels"].size(0)):
                sample_labels = batch["labels"][i]  # [seq_len]
                sample_preds = predictions[i].cpu()  # [seq_len]
                valid_mask = sample_labels != -100

                if valid_mask.sum().item() == 0:
                    # If no valid tokens, assign default class (supported).
                    true_example_label = 0
                    pred_example_label = 0
                else:
                    # Apply the valid mask and bring labels to CPU if needed.
                    sample_labels = sample_labels[valid_mask].cpu()
                    sample_preds = sample_preds[valid_mask]

                    # If any token in the sample is hallucinated (1), consider the whole sample hallucinated.
                    true_example_label = 1 if (sample_labels == 1).any().item() else 0
                    pred_example_label = 1 if (sample_preds == 1).any().item() else 0

                example_labels.append(true_example_label)
                example_preds.append(pred_example_label)

    precision, recall, f1, _ = precision_recall_fscore_support(
        example_labels, example_preds, labels=[0, 1], average=None, zero_division=0
    )

    results: dict[str, dict[str, float]] = {
        "supported": {  # Class 0
            "precision": float(precision[0]),
            "recall": float(recall[0]),
            "f1": float(f1[0]),
        },
        "hallucinated": {  # Class 1
            "precision": float(precision[1]),
            "recall": float(recall[1]),
            "f1": float(f1[1]),
        },
    }

    if verbose:
        report = classification_report(
            example_labels,
            example_preds,
            target_names=["Supported", "Hallucinated"],
            digits=4,
            zero_division=0,
        )
        print("\nDetailed Example-Level Classification Report:")
        print(report)
        results["classification_report"] = report

    return results

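# Worked sketch of the example-level rule used above (the labels are illustrative):
#
#     sample_labels = torch.tensor([-100, 0, 0, 1, 0, -100])
#     valid = sample_labels != -100
#     example_label = 1 if (sample_labels[valid] == 1).any().item() else 0  # -> 1
#
# Predictions are collapsed the same way, so an example counts as a true positive
# only when both the gold labels and the predictions contain at least one
# hallucinated token among the non-ignored positions.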

def evaluate_detector_char_level(
    detector: HallucinationDetector, samples: list[RagTruthSample]
) -> dict[str, float]:
    """Evaluate the HallucinationDetector at the character level.

    This function assumes that each sample is a RagTruthSample with the attributes:
      - prompt: the prompt text.
      - answer: the answer text.
      - labels: a list of dictionaries where each dictionary has "start" and "end" keys
                indicating the character indices of the gold (human-labeled) span.

    It uses the detector (which should have been initialized with the appropriate model)
    to obtain predicted spans, compares those spans with the gold spans, and computes global
    precision, recall, and F1 based on character overlap.

    :param detector: The detector to evaluate.
    :param samples: A list of samples to evaluate.
    :return: A dictionary with global metrics: {"precision": ..., "recall": ..., "f1": ...}
    """
    total_overlap = 0
    total_predicted = 0
    total_gold = 0

    for sample in tqdm(samples, desc="Evaluating", leave=False):
        prompt = sample.prompt
        answer = sample.answer
        gold_spans = sample.labels
        predicted_spans = detector.predict_prompt(prompt, answer, output_format="spans")

        # Compute total predicted span length for this sample.
        sample_predicted_length = sum(pred["end"] - pred["start"] for pred in predicted_spans)
        total_predicted += sample_predicted_length

        # Compute total gold span length once for this sample.
        sample_gold_length = sum(gold["end"] - gold["start"] for gold in gold_spans)
        total_gold += sample_gold_length

        # Now, compute the overlap between each predicted span and each gold span.
        sample_overlap = 0
        for pred in predicted_spans:
            for gold in gold_spans:
                overlap_start = max(pred["start"], gold["start"])
                overlap_end = min(pred["end"], gold["end"])
                if overlap_end > overlap_start:
                    sample_overlap += overlap_end - overlap_start
        total_overlap += sample_overlap

    precision = total_overlap / total_predicted if total_predicted > 0 else 0.0
    recall = total_overlap / total_gold if total_gold > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return {"precision": precision, "recall": recall, "f1": f1}
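# Worked sketch of the character-overlap metric above (numbers are illustrative):
# with a single gold span [10, 20) and a single predicted span [15, 25), the
# overlap is min(20, 25) - max(10, 15) = 5 characters, so
#     precision = 5 / 10 = 0.5   # overlap / total predicted characters
#     recall    = 5 / 10 = 0.5   # overlap / total gold characters
#     f1        = 2 * 0.5 * 0.5 / (0.5 + 0.5) = 0.5
#
# Illustrative call (assumes a HallucinationDetector and a list of RagTruthSample
# objects, e.g. `test_samples`, built elsewhere):
#
#     char_metrics = evaluate_detector_char_level(detector, test_samples)
#     print(char_metrics)  # {"precision": ..., "recall": ..., "f1": ...}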