__pycache__
*.bin
*.safetensors
outputs/
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10
MIT License
Copyright (c) 2025 KR Labs
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
recursive-include lettucedetect *
# LettuceDetect
A lightweight hallucination detection model for LLM outputs in RAG pipelines.
## Paper
`LettuceDetect: A Hallucination Detection Framework for RAG Applications`
* https://arxiv.org/pdf/2502.17125
## Model Architecture
The model is based on `ModernBERT`. ModernBERT is a drop-in replacement for BERT and a state-of-the-art encoder-only transformer architecture that incorporates several modern design improvements over the original BERT model: Rotary Positional Embeddings (RoPE) to handle sequences of up to 8,192 tokens, an unpadding optimization to eliminate wasted computation on padding tokens, GeGLU activation layers for enhanced expressiveness, and alternating attention for more efficient attention computation.
![alt text](readme_imgs/arch.png)
## Algorithm
The algorithm builds on a BERT-style encoder: when computing the loss, the `question` and `context` tokens are masked out, and each `answer` token is classified into two classes (supported vs. hallucinated), yielding token-level `hallucination` detection.
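A minimal sketch of this labeling idea (illustrative only, not this repo's exact implementation): masked tokens receive the label `-100`, which PyTorch's `CrossEntropyLoss` ignores by default, so only answer tokens contribute to the loss.
```python
import torch
import torch.nn as nn

def build_labels(num_masked_tokens: int, answer_token_labels: list[int]) -> torch.Tensor:
    """Question/context tokens -> -100 (ignored); answer tokens -> 0 (supported) or 1 (hallucinated)."""
    return torch.tensor([-100] * num_masked_tokens + answer_token_labels, dtype=torch.long)

labels = build_labels(5, [0, 0, 1])           # 5 masked tokens, 3 answer tokens
logits = torch.randn(len(labels), 2)          # dummy per-token logits from a 2-class head
loss = nn.CrossEntropyLoss()(logits, labels)  # only the 3 answer tokens contribute
print(labels, loss.item())
```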
## Environment Setup
### Docker (Option 1)
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10
docker run --shm-size 100g --network=host --name=detect --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to this project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -e .
pip install streamlit
### Dockerfile (Option 2)
docker build -t <IMAGE_NAME>:<TAG> .
docker run --shm-size 100g --network=host --name=detect --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to this project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -e .
pip install streamlit
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded and installed from the Guanghe developer community (光合开发者社区): https://developer.hpccube.com/tool/
```
python: 3.10
torch: 2.4.1
```
2. Other, non-DCU-specific libraries can be installed directly following requirements.txt:
```
pip install -e .
pip install streamlit
```
## Dataset
RAGTruth [github](https://github.com/ParticleMedia/RAGTruth/tree/main/dataset) | [SCnet]()
After downloading, place the data in the `data/ragtruth` directory (create it yourself if needed) and process it with the following script:
```bash
python lettucedetect/preprocess/preprocess_ragtruth.py --input_dir data/ragtruth --output_dir data/ragtruth
```
## Training
```bash
python scripts/train.py \
--data-path data/ragtruth/ragtruth_data.json \
--model-name /path/to/ModernBERT-base \
--output-dir outputs/hallucination_detector \
--batch-size 4 \
--epochs 6 \
--learning-rate 1e-5
```
## Inference
```bash
python -m scripts.evaluate \
--model_path /path/to/hallucination_detector \
--data_path data/ragtruth/ragtruth_data.json \
--evaluation_type example_level
```
### streamlit
```bash
streamlit run demo/streamlit_demo.py
```
Note: before running, update the model path in `demo/streamlit_demo.py`.
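For reference, the line to change is the detector construction in the demo; the path below is only an example placeholder, not a real checkpoint location:
```python
# In demo/streamlit_demo.py — point model_path at your local checkpoint (example path):
detector = HallucinationDetector(
    method="transformer",
    model_path="/path/to/lettucedect-large-modernbert-en-v1",
)
```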
## Results
### Accuracy
|Device|F1-score|
|:---:|:---:|
|DCU|0.5588|
|NVIDIA GPU|0.5544|
## Application Scenarios
### Algorithm Category
`Text detection`
### Key Application Industries
`Education, retail, finance, telecommunications`
## Pretrained Weights
|model|url|
|:---:|:---:|
|lettucedect-base-modernbert-en-v1|[hf](https://hf-mirror.com/KRLabsOrg/lettucedect-base-modernbert-en-v1) \| [SCnet]() |
|lettucedect-large-modernbert-en-v1|[hf](https://hf-mirror.com/KRLabsOrg/lettucedect-large-modernbert-en-v1) \| [SCnet]() |
|ModernBERT-base|[hf](https://hf-mirror.com/answerdotai/ModernBERT-base/tree/main) \| [SCnet](http://113.200.138.88:18080/aimodels/answerdotai/ModernBERT-base)|
|ModernBERT-large|[hf](https://hf-mirror.com/answerdotai/ModernBERT-large) \| [SCnet](http://113.200.138.88:18080/aimodels/answerdotai/ModernBERT-large) |
## Source Repository & Issue Reporting
* https://developer.sourcefind.cn/codes/modelzoo/lettucedetect_pytorch
## References
* https://github.com/KRLabsOrg/LettuceDetect
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/kovacs/projects/LettuceDetect/lettucedetect/models/inference.py:84: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" labels = torch.tensor(labels, device=self.device)\n",
"Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. Falling back to non-compiled mode.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predictions: [{'start': 31, 'end': 71, 'confidence': 0.9891982674598694, 'text': ' The population of France is 69 million.'}]\n"
]
}
],
"source": [
"from lettucedetect.models.inference import HallucinationDetector\n",
"\n",
"# For a transformer-based approach:\n",
"detector = HallucinationDetector(\n",
" method=\"transformer\", model_path=\"KRLabsOrg/lettucedect-base-modernbert-en-v1\"\n",
")\n",
"\n",
"contexts = [\n",
" \"France is a country in Europe. The capital of France is Paris. The population of France is 67 million.\",\n",
"]\n",
"question = \"What is the capital of France? What is the population of France?\"\n",
"answer = \"The capital of France is Paris. The population of France is 69 million.\"\n",
"\n",
"# Get span-level predictions indicating which parts of the answer are considered hallucinated.\n",
"predictions = detector.predict(\n",
" context=contexts, question=question, answer=answer, output_format=\"spans\"\n",
")\n",
"print(\"Predictions:\", predictions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "lettucedetect",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import streamlit as st
import streamlit.components.v1 as components
from lettucedetect.models.inference import HallucinationDetector
def create_interactive_text(text: str, spans: list[dict[str, int | float]]) -> str:
"""Create interactive HTML with highlighting and hover effects.
:param text: The text to create the interactive text for.
:param spans: The spans to highlight.
:return: The interactive text.
"""
html_text = text
for span in sorted(spans, key=lambda x: x["start"], reverse=True):
span_text = text[span["start"] : span["end"]]
highlighted_span = f'<span class="hallucination" title="Confidence: {span["confidence"]:.3f}">{span_text}</span>'
html_text = html_text[: span["start"]] + highlighted_span + html_text[span["end"] :]
return f"""
<style>
.container {{
font-family: Arial, sans-serif;
font-size: 16px;
line-height: 1.6;
padding: 20px;
}}
.hallucination {{
background-color: rgba(255, 99, 71, 0.3);
padding: 2px;
border-radius: 3px;
cursor: help;
}}
.hallucination:hover {{
background-color: rgba(255, 99, 71, 0.5);
}}
</style>
<div class="container">{html_text}</div>
"""
def main():
st.set_page_config(page_title="Lettuce Detective")
st.image(
"https://github.com/KRLabsOrg/LettuceDetect/blob/main/assets/lettuce_detective.png?raw=true",
width=600,
)
st.title("Let Us Detect Your Hallucinations")
@st.cache_resource
def load_detector():
return HallucinationDetector(
method="transformer",
model_path="/home/LettuceDetect/ckpts/lettucedect-large-modernbert-en-v1",
)
detector = load_detector()
context = st.text_area(
"Context",
"France is a country in Europe. The capital of France is Paris. The population of France is 67 million.",
height=100,
)
question = st.text_area(
"Question",
"What is the capital of France? What is the population of France?",
height=100,
)
answer = st.text_area(
"Answer",
"The capital of France is Paris. The population of France is 69 million.",
height=100,
)
if st.button("Detect Hallucinations"):
predictions = detector.predict(
context=[context], question=question, answer=answer, output_format="spans"
)
html_content = create_interactive_text(answer, predictions)
components.html(html_content, height=200)
if __name__ == "__main__":
main()
# LettuceDetect: A Hallucination Detection Framework for RAG Applications
<p align="center">
<img src="https://github.com/KRLabsOrg/LettuceDetect/blob/main/assets/lettuce_detective.png?raw=true" alt="LettuceDetect Logo" width="400"/>
<br><em>Because even AI needs a reality check! 🥬</em>
</p>
## TL;DR
We present **LettuceDetect**, a lightweight hallucination detector for Retrieval-Augmented Generation (RAG) pipelines. It is an **encoder-based** model built on [ModernBERT](https://github.com/AnswerDotAI/ModernBERT), released under the MIT license with ready-to-use Python packages and pretrained models.
- **What**: LettuceDetect is a token-level detector that flags unsupported segments in LLM answers. 🥬
- **How**: Trained on [RAGTruth](https://aclanthology.org/2024.acl-long.585/) (18k examples), leveraging ModernBERT for context lengths up to 4k tokens. 🚀
- **Why**: It addresses (1) the context-window limits in prior encoder-only models, and (2) the high compute costs of LLM-based detectors. ⚖️
- **Highlights**:
- Beats prior encoder-based models (e.g., [Luna](https://aclanthology.org/2025.coling-industry.34/)) on RAGTruth. ✅
- Surpasses fine-tuned Llama-2-13B [2] at a fraction of the size, and is highly efficient at inference. ⚡️
- Entirely **open-source** with an MIT license. 🔓
**LettuceDetect** keeps your RAG framework **fresh** by spotting **rotten** parts of your LLM's outputs. 😊
## Quick Links
- **GitHub**: [github.com/KRLabsOrg/LettuceDetect](https://github.com/KRLabsOrg/LettuceDetect)
- **PyPI**: [pypi.org/project/lettucedetect](https://pypi.org/project/lettucedetect/)
- **arXiv Paper**: [2502.17125](https://arxiv.org/abs/2502.17125)
- **Hugging Face Models**:
- [Base Model](https://huggingface.co/KRLabsOrg/lettucedect-base-modernbert-en-v1)
- [Large Model](https://huggingface.co/KRLabsOrg/lettucedect-large-modernbert-en-v1)
- **Streamlit Demo**: Visit our [Hugging Face Space](https://huggingface.co/spaces/KRLabsOrg/LettuceDetect) or run locally per the GitHub instructions.
---
## Get Going
Install the package:
```bash
pip install lettucedetect
```
Then, you can use the package as follows:
```python
from lettucedetect.models.inference import HallucinationDetector
# For a transformer-based approach:
detector = HallucinationDetector(
method="transformer", model_path="KRLabsOrg/lettucedect-base-modernbert-en-v1"
)
contexts = ["France is a country in Europe. The capital of France is Paris. The population of France is 67 million.",]
question = "What is the capital of France? What is the population of France?"
answer = "The capital of France is Paris. The population of France is 69 million."
# Get span-level predictions indicating which parts of the answer are considered hallucinated.
predictions = detector.predict(context=contexts, question=question, answer=answer, output_format="spans")
print("Predictions:", predictions)
# Predictions: [{'start': 31, 'end': 71, 'confidence': 0.9944414496421814, 'text': ' The population of France is 69 million.'}]
```
## Why LettuceDetect?
Large Language Models (LLMs) such as GPT-4 [4], the Llama-3 models [5], and Mistral [6] (among many others) have driven considerable advances in NLP. Despite this success, **hallucinations** remain a key obstacle to deploying LLMs in high-stakes scenarios (such as healthcare or legal) [7,8].
**Retrieval-Augmented Generation (RAG)** attempts to mitigate hallucinations by grounding an LLM's responses in retrieved documents, providing external knowledge that the model can reference [9]. But even though RAG is a powerful method to reduce hallucinations, LLMs still suffer from hallucinations in these settings [1]. Hallucinations are information in the output that is nonsensical, factually incorrect, or inconsistent with the retrieved context [8]. Ji et al. [10] categorize hallucinations into:
- Intrinsic Hallucinations: Stemming from the model's preexisting internal knowledge.
- Extrinsic Hallucinations: Occurring when the answer conflicts with the context or references provided.
While RAG approaches can mitigate intrinsic hallucinations, they are not immune to extrinsic hallucinations. Sun et al. [11] showed that models tend to prioritize their intrinsic knowledge over the external context. As LLMs remain prone to hallucinations, their application in critical domains, e.g. medical or legal, can still be flawed.
### Current Solutions for Hallucination Detection
Current solutions for hallucination detection can be grouped by the approach they take:
1. **Prompt based detectors**
These methods (e.g., [RAGAS](https://github.com/explodinggradients/ragas), [Trulens](https://github.com/truera/trulens), [ARES](https://github.com/stanford-futuredata/ARES)) typically leverage zero-shot or few-shot prompts to detect hallucinations. They often rely on large LLMs (like GPT-4) and employ strategies such as SelfCheckGPT [12], LM vs. LM [13], or Chainpoll [14]. While often effective, they can be computationally expensive due to repeated LLM calls.
2. **Fine-tuned LLM detectors**
Large models (e.g., Llama-2, Llama-3) can be fine-tuned for hallucination detection [1,15]. This can yield high accuracy (as shown by the RAGTruth authors using Llama-2-13B or the RAG-HAT work on Llama-3-8B) but is resource-intensive to train and deploy. Inference costs also tend to be high due to their size and slower speeds.
3. **Encoder-based detectors**
Models like [Luna [2]](https://aclanthology.org/2025.coling-industry.34/) rely on a BERT-style encoder (often limited to 512 tokens) for token-level classification. These methods are generally more efficient than running a full LLM at inference but are constrained by **short context windows** and attention mechanisms optimized for smaller inputs.
### ModernBERT for long context
ModernBERT [3] is a drop-in replacement for BERT and a state-of-the-art encoder-only transformer architecture that incorporates several modern design improvements over the original BERT model: Rotary Positional Embeddings ([RoPE](https://huggingface.co/blog/designing-positional-encoding)) to handle sequences of up to 8,192 tokens, an [unpadding optimization](https://arxiv.org/abs/2208.08124) to eliminate wasted computation on padding tokens, [GeGLU](https://arxiv.org/abs/2002.05202) activation layers for enhanced expressiveness, and [alternating attention](https://arxiv.org/abs/2004.05150v2) for more efficient attention computation.
**LettuceDetect** capitalizes on ModernBERT’s extended context window to build a token-level classifier for hallucination detection. This approach sidesteps many limitations of older BERT-based models (e.g., short context bounds) and avoids the inference overhead of large LLM-based detectors. Our experiments show that LettuceDetect outperforms other encoder-based systems while remaining **competitive with fine-tuned LLM detectors** at a fraction of their computational cost.
## Data
[RAGTruth](https://aclanthology.org/2024.acl-long.585/) is the first large-scale benchmark specifically designed to evaluate **hallucination detection** in Retrieval-Augmented Generation (RAG) settings. It contains **18,000** annotated examples spanning multiple tasks:
- **Question Answering (QA)**: Sampled from the [MS MARCO](https://microsoft.github.io/msmarco/) dataset, where up to three documents are retrieved for each question, and an LLM is prompted to produce an answer.
- **Data-to-Text Generation**: Based on the [Yelp Open Dataset](https://www.yelp.com/dataset/), where LLMs generate reviews for sampled businesses.
- **News Summarization**: Uses random documents from [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail), prompting LLMs to provide summaries.
Each example in RAGTruth includes:
1. A **question** and **prompt**,
2. One or more **reference passages** or **context documents**,
3. An **LLM-generated response**,
4. **Human annotations** at the span level indicating whether any part of the answer is factually unsupported or contradictory to the reference passages.
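For concreteness, one such record can be pictured roughly as the following Python dict; the field names and offsets here are illustrative, not the exact RAGTruth schema:
```python
# Illustrative shape of one annotated RAGTruth-style example (not the official schema).
example = {
    "question": "How to explain quartiles?",
    "context": ["Passage 1 ...", "Passage 2: Second quartile (Q2) ... splits lowest 75% ..."],
    "answer": "... the third quartile (Q3) splits the highest 75% of the data.",
    "labels": [  # span-level human annotations over the answer (character offsets)
        {"start": 48, "end": 59, "type": "Evident Conflict"},
    ],
}
```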
Below is an illustrative sample from RAGTruth, showing an instance of hallucination:
---
**QUESTION**
How to explain quartiles?
**REFERENCE**
> **Passage 2:**
> Second quartile (Q2) which is more commonly known as median splits the data in half (50%).
> Median divides the data into a lower half and an upper half.
> Third quartile (Q3), also known as upper quartile, *splits lowest 75% (or highest 25%)* of data.
> *(...)*
**RESPONSE**
> The first quartile (Q1) splits the lowest 25% of the data, while the second quartile (Q2) splits the data into two equal halves, with the median being the middle value of the lower half. Finally, the third quartile (Q3) splits *the highest 75% of the data.*
> *(...)*
**ANNOTATION**
- **Span**: “highest 75%”
- **Type**: Evident Conflict
- **Reason**:
- Original text: “lowest 75% (or highest 25%)”
- Generative text: “highest 75%”
---
This example showcases a RAGTruth annotation involving specific spans in the answer that conflict with the reference passages, providing both the **type** of hallucination (e.g., Evident Conflict) and a human-readable **rationale**. Our method uses only the binary annotations and does not make use of the hallucination-type labels.
## Method
<p align="center">
<img src="https://github.com/KRLabsOrg/LettuceDetect/blob/main/assets/architecture.png?raw=true" alt="LettuceDetect Architecture" width="800"/>
</p>
**A high-level depiction of LettuceDetect.** Here, an example **Question**, **Context**, and **Answer** triplet is processed. First, the text is tokenized, after which LettuceDetect performs token-level classification. Tokens from both the question and context are **masked** (indicated by the red line in the figure) to exclude them from the loss function. Each token in the answer receives a probability indicating whether it is hallucinated or supported. For *span-level* detection, we merge consecutive tokens with hallucination probabilities above 0.5 into a single predicted span.
---
We train **ModernBERT-base** and **ModernBERT-large** variants as token-classification models on the RAGTruth dataset. The input to the model is a concatenation of **Context**, **Question**, and **Answer** segments, with the special tokens `[CLS]` (for the context) and `[SEP]` (as separators). We limit the sequence length to 4,096 tokens for computational feasibility, though ModernBERT can theoretically handle up to 8,192 tokens.
### Tokenization and Data Processing
- **Tokenizer**: We employ _AutoTokenizer_ from the Transformers library to handle subword tokenization, inserting `[CLS]` and `[SEP]` appropriately.
- **Labeling**:
- Context/question tokens are *masked* (i.e., assigned a label of `-100` in PyTorch) so that they do not contribute to the loss.
- Each *answer* token receives a label of **0** (supported) or **1** (hallucinated).
### Model Architecture
Our models build on Hugging Face’s _AutoModelForTokenClassification_, using **ModernBERT** as the encoder and a classification head on top. Unlike some previous encoder-based approaches (e.g., ones pre-trained on NLI tasks), our method uses *only* ModernBERT with no additional pretraining stage.
### Training Configuration
- **Optimizer**: AdamW, with a learning rate of 1 * 10^-5 and weight decay of 0.01.
- **Hardware**: Single NVIDIA A100 GPU.
- **Epochs**: 6 total training epochs.
- **Batching**:
- Batch size of 8,
- Data loading with PyTorch _DataLoader_ (shuffling enabled),
- Dynamic padding via _DataCollatorForTokenClassification_ to handle variable-length sequences efficiently.
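A minimal sketch of this configuration using the Hugging Face APIs named above (this is not the project's actual training script; the checkpoint name and the tiny stand-in dataset are assumptions made for illustration):

```python
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
)

# Assumed checkpoint name; in the real pipeline the dataset would be built
# from the preprocessed RAGTruth data rather than the toy features below.
model_name = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=2)

# Tiny stand-in features so the sketch runs end to end: -100 labels for masked
# positions, 0/1 labels for the (here, last two) answer tokens.
enc = tokenizer(["context one", "context two"], ["answer one", "answer two"], truncation=True)
train_features = [
    {"input_ids": ids, "attention_mask": mask, "labels": [-100] * (len(ids) - 2) + [0, 1]}
    for ids, mask in zip(enc["input_ids"], enc["attention_mask"])
]

collator = DataCollatorForTokenClassification(tokenizer)  # dynamic padding of inputs and labels
train_loader = DataLoader(train_features, batch_size=8, shuffle=True, collate_fn=collator)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
for epoch in range(6):
    model.train()
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(**batch).loss  # cross-entropy over unmasked (answer) tokens only
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```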
During training, we monitor **token-level F1** scores on a validation split, saving checkpoints using the _safetensors_ format. Once training is complete, we upload the best-performing models to Hugging Face for public access.
At inference time, the model outputs a probability of hallucination for each token in the answer. We aggregate consecutive tokens exceeding a 0.5 threshold to produce **span-level** predictions, indicating exactly which segments of the answer are likely to be hallucinated. The figure above illustrates this workflow.
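As a rough illustration of that aggregation step, consecutive above-threshold tokens can be folded into character spans as below; the function name and the way confidences are combined are assumptions for this sketch, and the library's exact merging logic may differ.

```python
def merge_spans(token_probs, offsets, threshold=0.5):
    """Merge consecutive answer tokens whose hallucination probability exceeds
    `threshold` into character-level spans (illustrative sketch)."""
    spans, current = [], None
    for prob, (start, end) in zip(token_probs, offsets):
        if prob > threshold:
            if current is None:
                current = {"start": start, "end": end, "confidence": prob}
            else:  # extend the open span and keep the strongest token confidence
                current["end"] = end
                current["confidence"] = max(current["confidence"], prob)
        elif current is not None:
            spans.append(current)
            current = None
    if current is not None:
        spans.append(current)
    return spans

# Example: two contiguous flagged tokens merge into one span, the isolated one stays separate.
print(merge_spans([0.1, 0.9, 0.8, 0.2, 0.7],
                  [(0, 4), (5, 9), (10, 14), (15, 19), (20, 24)]))
```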
Next, we provide a more detailed evaluation of the model’s performance.
## Results
We evaluate our models on the RAGTruth test set across all task types (Question Answering, Data-to-Text, and Summarization). For each example, RAGTruth includes manually annotated spans indicating hallucinated content.
### Example-Level Results
We first assess the **example-level** question: *Does the generated answer contain any hallucination at all?* Our large model (**lettucedetect-large-v1**) attains an **overall F1 score of 79.22%**, surpassing:
- **GPT-4** (63.4%),
- **Luna** (65.4%) (the previous state-of-the-art encoder-based model),
- **Fine-tuned Llama-2-13B** (78.7%) as presented in the RAGTruth paper.
It is second only to the fine-tuned **Llama-3-8B** from the **RAG-HAT** paper [15] (83.9%), but LettuceDetect is significantly **smaller** and **faster** to run. Meanwhile, our base model (**lettucedetect-base-v1**) remains highly competitive while using fewer parameters.
<p align="center">
<img src="https://github.com/KRLabsOrg/LettuceDetect/blob/main/assets/example_level_lettucedetect.png?raw=true" alt="Example-level Results" width="800"/>
</p>
Above is a comparison table illustrating how LettuceDetect stacks up against both prompt-based methods (e.g., GPT-4) and alternative encoder-based solutions (e.g., Luna). Overall, **lettucedetect-large-v1** and **lettucedect-base-v1** perform strongly while remaining highly efficient at inference.
### Span-Level Results
Beyond detecting if an answer contains hallucinations, we also examine LettuceDetect’s ability to identify the **exact spans** of unsupported content. Here, LettuceDetect achieves **state-of-the-art** results among models that have reported span-level performance, substantially outperforming the fine-tuned Llama-2-13B model from the RAGTruth paper [1] and other baselines.
<p align="center">
<img src="https://github.com/KRLabsOrg/LettuceDetect/blob/main/assets/span_level_lettucedetect.png?raw=true" alt="Span-level Results" width="800"/>
</p>
Most methods, like RAG-HAT [15], do not report span-level metrics, so we do not compare to them here.
### Inference Efficiency
Both **lettucedetect-base-v1** and **lettucedetect-large-v1** require fewer parameters than typical LLM-based detectors (e.g., GPT-4 or Llama-3-8B) and can process **30–60 examples per second** on a single NVIDIA A100 GPU. This makes them practical for industrial workloads, real-time user-facing systems, and resource-constrained environments.
Overall, these results show that **LettuceDetect** strikes a good balance: it achieves near state-of-the-art accuracy at a fraction of the size and cost of large LLM-based judges, while offering precise, token-level hallucination detection.
## Conclusion
We introduced **LettuceDetect**, a lightweight and efficient framework for hallucination detection in RAG systems. By utilizing ModernBERT’s extended context capabilities, our models achieve strong performance on the RAGTruth benchmark while retaining **high inference efficiency**. This work lays the groundwork for future research directions, such as expanding to additional datasets, supporting multiple languages, and exploring more advanced architectures. Even at this stage, LettuceDetect demonstrates that **effective hallucination detection** can be achieved using **lean, purpose-built** encoder-based models.
## Citation
If you find this work useful, please cite it as follows:
```bibtex
@misc{Kovacs:2025,
title={LettuceDetect: A Hallucination Detection Framework for RAG Applications},
author={Ádám Kovács and Gábor Recski},
year={2025},
eprint={2502.17125},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2502.17125},
}
```
Also, if you use our code, please don't forget to give us a star ⭐ on our GitHub repository [here](https://github.com/KRLabsOrg/LettuceDetect).
## References
[1] [Niu et al., 2024, RAGTruth: A Dataset for Hallucination Detection in Retrieval-Augmented Generation](https://aclanthology.org/2024.acl-long.585/)
[2] [Luna: A Simple and Effective Encoder-Based Model for Hallucination Detection in Retrieval-Augmented Generation](https://aclanthology.org/2025.coling-industry.34/)
[3] [ModernBERT: A Modern BERT Model for Long-Context Processing](https://huggingface.co/blog/modernbert)
[4] [GPT-4 report](https://arxiv.org/abs/2303.08774)
[5] [Llama-3 report](https://arxiv.org/abs/2407.21783)
[6] [Mistral 7B](https://arxiv.org/abs/2310.06825)
[7] [Kaddour et al., 2023, Challenges and Applications of Large Language Models](https://arxiv.org/pdf/2307.10169)
[8] [Huang et al., 2025, A Survey on Hallucination in Large Language Models: Principles, Taxonomy, Challenges, and Open Questions](https://dl.acm.org/doi/10.1145/3703155)
[9] [Gao et al., 2024, Retrieval-Augmented Generation for Large Language Models: A Survey](https://arxiv.org/abs/2312.10997)
[10] [Ji et al., 2023, Survey of Hallucination in Natural Language Generation](https://doi.org/10.1145/3571730)
[11] [Sun et al., 2025, ReDeEP: Detecting Hallucination in Retrieval-Augmented Generation via Mechanistic Interpretability](https://arxiv.org/abs/2410.11414)
[12] [Manakul et al., 2023, SelfCheckGPT: Zero-Resource Black-Box Hallucination Detection for Generative Large Language Models](https://aclanthology.org/2023.emnlp-main.557/)
[13] [Cohen et al., 2023, LM vs LM: Detecting Factual Errors via Cross Examination](https://aclanthology.org/2023.emnlp-main.778/)
[14] [Friel et al., 2023, Chainpoll: A high efficacy method for LLM hallucination detection](https://arxiv.org/abs/2310.18344)
[15] [Song et al., 2024, RAG-HAT: A Hallucination-Aware Tuning Pipeline for LLM in Retrieval-Augmented Generation](https://aclanthology.org/2024.emnlp-industry.113/)
[16] [Devlin et al., 2019, BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from lettucedetect.preprocess.preprocess_ragbench import RagBenchSample
class RagBenchDataset(Dataset):
"""Dataset for RAG Bench data."""
def __init__(
self,
samples: list[RagBenchSample],
tokenizer: AutoTokenizer,
max_length: int = 4096,
):
"""Initialize the dataset.
:param samples: List of RagBenchSample objects.
:param tokenizer: Tokenizer to use for encoding the data.
:param max_length: Maximum length of the input sequence.
"""
self.samples = samples
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self) -> int:
return len(self.samples)
@classmethod
def prepare_tokenized_input(
cls,
tokenizer: AutoTokenizer,
context: str,
answer: str,
max_length: int = 4096,
) -> tuple[dict[str, torch.Tensor], list[int], torch.Tensor, int]:
"""Tokenizes the context and answer together, computes the answer start token index,
and initializes a labels list (using -100 for context tokens and 0 for answer tokens).
:param tokenizer: The tokenizer to use.
:param context: The context string.
:param answer: The answer string.
:param max_length: Maximum input sequence length.
:return: A tuple containing:
- encoding: A dict of tokenized inputs without offset mapping.
- labels: A list of initial token labels.
- offsets: Offset mappings for each token (as a tensor of shape [seq_length, 2]).
- answer_start_token: The index where answer tokens begin.
"""
encoding = tokenizer(
context,
answer,
truncation=True,
max_length=max_length,
return_offsets_mapping=True,
return_tensors="pt",
add_special_tokens=True,
)
offsets = encoding.pop("offset_mapping")[0] # shape: (seq_length, 2)
# Compute start token index.
prompt_tokens = tokenizer.encode(context, add_special_tokens=False)
# Tokenization adds [CLS] at the start and [SEP] after the prompt.
answer_start_token = 1 + len(prompt_tokens) + 1 # [CLS] + prompt tokens + [SEP]
# Initialize labels: -100 for tokens before the answer, 0 for tokens in the answer.
labels = [-100] * encoding["input_ids"].shape[1]
return encoding, labels, offsets, answer_start_token
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
"""Get an item from the dataset.
:param idx: Index of the item to get.
:return: Dictionary with input IDs, attention mask, and labels.
"""
sample = self.samples[idx]
# Use the shared class method to perform tokenization and initial label setup.
encoding, labels, offsets, answer_start = RagBenchDataset.prepare_tokenized_input(
self.tokenizer, sample.prompt, sample.answer, self.max_length
)
# Adjust the token labels based on the annotated hallucination spans.
# Compute the character offset of the first answer token.
answer_char_offset = offsets[answer_start][0] if answer_start < len(offsets) else None
for i in range(answer_start, encoding["input_ids"].shape[1]):
token_start, token_end = offsets[i]
# Adjust token offsets relative to answer text.
token_abs_start = (
token_start - answer_char_offset if answer_char_offset is not None else token_start
)
token_abs_end = (
token_end - answer_char_offset if answer_char_offset is not None else token_end
)
# Default label is 0 (supported content).
token_label = 0
# If token overlaps any annotated hallucination span, mark it as hallucinated (1).
for ann in sample.labels:
if token_abs_end > ann["start"] and token_abs_start < ann["end"]:
token_label = 1
break
labels[i] = token_label
labels = torch.tensor(labels, dtype=torch.long)
return {
"input_ids": encoding["input_ids"].squeeze(0),
"attention_mask": encoding["attention_mask"].squeeze(0),
"labels": labels,
}
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from lettucedetect.preprocess.preprocess_ragtruth import RagTruthSample
class RagTruthDataset(Dataset):
"""Dataset for RAG truth data."""
def __init__(
self,
samples: list[RagTruthSample],
tokenizer: AutoTokenizer,
max_length: int = 4096,
):
"""Initialize the dataset.
:param samples: List of RagTruthSample objects.
:param tokenizer: Tokenizer to use for encoding the data.
:param max_length: Maximum length of the input sequence.
"""
self.samples = samples
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self) -> int:
return len(self.samples)
@classmethod
def prepare_tokenized_input(
cls,
tokenizer: AutoTokenizer,
context: str,
answer: str,
max_length: int = 4096,
) -> tuple[dict[str, torch.Tensor], list[int], torch.Tensor, int]:
"""Tokenizes the context and answer together, computes the answer start token index,
and initializes a labels list (using -100 for context tokens and 0 for answer tokens).
:param tokenizer: The tokenizer to use.
:param context: The context string.
:param answer: The answer string.
:param max_length: Maximum input sequence length.
:return: A tuple containing:
- encoding: A dict of tokenized inputs without offset mapping.
- labels: A list of initial token labels.
- offsets: Offset mappings for each token (as a tensor of shape [seq_length, 2]).
- answer_start_token: The index where answer tokens begin.
"""
encoding = tokenizer(
context,
answer,
truncation=True,
max_length=max_length,
return_offsets_mapping=True,
return_tensors="pt",
add_special_tokens=True,
)
# Extract and remove offset mapping.
offsets = encoding.pop("offset_mapping")[0] # shape: (seq_length, 2)
# Compute answer start token index.
prompt_tokens = tokenizer.encode(context, add_special_tokens=False)
# Tokenization adds [CLS] at the start and [SEP] after the prompt.
answer_start_token = 1 + len(prompt_tokens) + 1 # [CLS] + prompt tokens + [SEP]
# Initialize labels: -100 for tokens before the answer, 0 for tokens in the answer.
labels = [-100] * encoding["input_ids"].shape[1]
return encoding, labels, offsets, answer_start_token
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
"""Get an item from the dataset.
:param idx: Index of the item to get.
:return: Dictionary with input IDs, attention mask, and labels.
"""
sample = self.samples[idx]
# Use the shared class method to perform tokenization and initial label setup.
encoding, labels, offsets, answer_start = RagTruthDataset.prepare_tokenized_input(
self.tokenizer, sample.prompt, sample.answer, self.max_length
)
# Adjust the token labels based on the annotated hallucination spans.
# Compute the character offset of the first answer token.
answer_char_offset = offsets[answer_start][0] if answer_start < len(offsets) else None
for i in range(answer_start, encoding["input_ids"].shape[1]):
token_start, token_end = offsets[i]
# Adjust token offsets relative to answer text.
token_abs_start = (
token_start - answer_char_offset if answer_char_offset is not None else token_start
)
token_abs_end = (
token_end - answer_char_offset if answer_char_offset is not None else token_end
)
# Default label is 0 (supported content).
token_label = 0
# If token overlaps any annotated hallucination span, mark it as hallucinated (1).
for ann in sample.labels:
if token_abs_end > ann["start"] and token_abs_start < ann["end"]:
token_label = 1
break
labels[i] = token_label
labels = torch.tensor(labels, dtype=torch.long)
return {
"input_ids": encoding["input_ids"].squeeze(0),
"attention_mask": encoding["attention_mask"].squeeze(0),
"labels": labels,
}
import torch
from sklearn.metrics import classification_report, precision_recall_fscore_support
from torch.nn import Module
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from lettucedetect.datasets.ragtruth import RagTruthSample
from lettucedetect.models.inference import HallucinationDetector
def evaluate_model(
model: Module,
dataloader: DataLoader,
device: torch.device,
verbose: bool = True,
) -> dict[str, dict[str, float]]:
"""Evaluate a model for hallucination detection.
:param model: The model to evaluate.
:param dataloader: The data loader to use for evaluation.
:param device: The device to use for evaluation.
:param verbose: If True, print the evaluation metrics.
:return: A dictionary containing the evaluation metrics.
{
"supported": {"precision": float, "recall": float, "f1": float},
"hallucinated": {"precision": float, "recall": float, "f1": float}
}
"""
model.eval()
all_preds: list[int] = []
all_labels: list[int] = []
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating", leave=False):
outputs = model(
batch["input_ids"].to(device),
attention_mask=batch["attention_mask"].to(device),
)
logits: torch.Tensor = outputs.logits
predictions = torch.argmax(logits, dim=-1)
# Only evaluate on tokens that have labels (not -100)
mask = batch["labels"] != -100
predictions = predictions[mask].cpu().numpy()
labels = batch["labels"][mask].cpu().numpy()
all_preds.extend(predictions.tolist())
all_labels.extend(labels.tolist())
precision, recall, f1, _ = precision_recall_fscore_support(
all_labels, all_preds, labels=[0, 1], average=None
)
results: dict[str, dict[str, float]] = {
"supported": { # Class 0
"precision": float(precision[0]),
"recall": float(recall[0]),
"f1": float(f1[0]),
},
"hallucinated": { # Class 1
"precision": float(precision[1]),
"recall": float(recall[1]),
"f1": float(f1[1]),
},
}
if verbose:
report = classification_report(
all_labels, all_preds, target_names=["Supported", "Hallucinated"], digits=4
)
print("\nDetailed Classification Report:")
print(report)
results["classification_report"] = report
return results
def print_metrics(metrics: dict[str, dict[str, float]]) -> None:
"""Print evaluation metrics in a readable format.
:param metrics: A dictionary containing the evaluation metrics.
:return: None
"""
print("\nEvaluation Results:")
print("\nHallucination Detection (Class 1):")
print(f" Precision: {metrics['hallucinated']['precision']:.4f}")
print(f" Recall: {metrics['hallucinated']['recall']:.4f}")
print(f" F1: {metrics['hallucinated']['f1']:.4f}")
print("\nSupported Content (Class 0):")
print(f" Precision: {metrics['supported']['precision']:.4f}")
print(f" Recall: {metrics['supported']['recall']:.4f}")
print(f" F1: {metrics['supported']['f1']:.4f}")
def evaluate_model_example_level(
model: Module,
dataloader: DataLoader,
device: torch.device,
verbose: bool = True,
) -> dict[str, dict[str, float]]:
"""Evaluate a model for hallucination detection at the example level.
For each example, if any token is marked as hallucinated (label=1),
then the whole example is considered hallucinated. Otherwise, it is supported.
:param model: The model to evaluate.
:param dataloader: DataLoader providing the evaluation batches.
:param device: Device on which to perform evaluation.
:param verbose: If True, prints a detailed classification report.
:return: A dict containing example-level metrics:
{
"supported": {"precision": float, "recall": float, "f1": float},
"hallucinated": {"precision": float, "recall": float, "f1": float}
}
"""
model.eval()
example_preds: list[int] = []
example_labels: list[int] = []
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating (Example Level)", leave=False):
# Move inputs to device. Note that `batch["labels"]`
# can stay on CPU if you wish to avoid unnecessary transfers.
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
outputs = model(input_ids, attention_mask=attention_mask)
logits: torch.Tensor = outputs.logits # Shape: [batch_size, seq_len, num_labels]
predictions: torch.Tensor = torch.argmax(logits, dim=-1) # Shape: [batch_size, seq_len]
# Process each example in the batch separately.
for i in range(batch["labels"].size(0)):
sample_labels = batch["labels"][i] # [seq_len]
sample_preds = predictions[i].cpu() # [seq_len]
valid_mask = sample_labels != -100
if valid_mask.sum().item() == 0:
# If no valid tokens, assign default class (supported).
true_example_label = 0
pred_example_label = 0
else:
# Apply the valid mask and bring labels to CPU if needed.
sample_labels = sample_labels[valid_mask].cpu()
sample_preds = sample_preds[valid_mask]
# If any token in the sample is hallucinated (1), consider the whole sample hallucinated.
true_example_label = 1 if (sample_labels == 1).any().item() else 0
pred_example_label = 1 if (sample_preds == 1).any().item() else 0
example_labels.append(true_example_label)
example_preds.append(pred_example_label)
precision, recall, f1, _ = precision_recall_fscore_support(
example_labels, example_preds, labels=[0, 1], average=None, zero_division=0
)
results: dict[str, dict[str, float]] = {
"supported": { # Class 0
"precision": float(precision[0]),
"recall": float(recall[0]),
"f1": float(f1[0]),
},
"hallucinated": { # Class 1
"precision": float(precision[1]),
"recall": float(recall[1]),
"f1": float(f1[1]),
},
}
if verbose:
report = classification_report(
example_labels,
example_preds,
target_names=["Supported", "Hallucinated"],
digits=4,
zero_division=0,
)
print("\nDetailed Example-Level Classification Report:")
print(report)
results["classification_report"] = report
return results
def evaluate_detector_char_level(
detector: HallucinationDetector, samples: list[RagTruthSample]
) -> dict[str, float]:
"""Evaluate the HallucinationDetector at the character level.
This function assumes that each sample is a dictionary containing:
- "prompt": the prompt text.
- "answer": the answer text.
- "gold_spans": a list of dictionaries where each dictionary has "start" and "end" keys
indicating the character indices of the gold (human-labeled) span.
It uses the detector (which should have been initialized with the appropriate model)
to obtain predicted spans, compares those spans with the gold spans, and computes global
precision, recall, and F1 based on character overlap.
:param detector: The detector to evaluate.
:param samples: A list of samples to evaluate.
:return: A dictionary with global metrics: {"char_precision": ..., "char_recall": ..., "char_f1": ...}
"""
total_overlap = 0
total_predicted = 0
total_gold = 0
for sample in tqdm(samples, desc="Evaluating", leave=False):
prompt = sample.prompt
answer = sample.answer
gold_spans = sample.labels
predicted_spans = detector.predict_prompt(prompt, answer, output_format="spans")
# Compute total predicted span length for this sample.
sample_predicted_length = sum(pred["end"] - pred["start"] for pred in predicted_spans)
total_predicted += sample_predicted_length
# Compute total gold span length once for this sample.
sample_gold_length = sum(gold["end"] - gold["start"] for gold in gold_spans)
total_gold += sample_gold_length
# Now, compute the overlap between each predicted span and each gold span.
sample_overlap = 0
for pred in predicted_spans:
for gold in gold_spans:
overlap_start = max(pred["start"], gold["start"])
overlap_end = min(pred["end"], gold["end"])
if overlap_end > overlap_start:
sample_overlap += overlap_end - overlap_start
total_overlap += sample_overlap
precision = total_overlap / total_predicted if total_predicted > 0 else 0
recall = total_overlap / total_gold if total_gold > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
return {"precision": precision, "recall": recall, "f1": f1}