# NOTE(review): removed non-Python residue here (repository web-UI chrome and a
# scraped line-number gutter) that prevented this module from being imported.
"""Pytest tests for the inference module."""

from unittest.mock import MagicMock, patch

import pytest
import torch

from lettucedetect.models.inference import HallucinationDetector, TransformerDetector


@pytest.fixture
def mock_tokenizer():
    """Provide a stub tokenizer whose ``encode`` yields a fixed id sequence."""
    stub = MagicMock()
    stub.encode.return_value = [101, 102, 103, 104, 105]
    return stub


@pytest.fixture
def mock_model():
    """Provide a stub token-classification model emitting fixed logits."""
    stub_output = MagicMock()
    stub_output.logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]])
    # Calling the model mock returns the canned output object.
    stub_model = MagicMock(return_value=stub_output)
    return stub_model


class TestHallucinationDetector:
    """Tests for the HallucinationDetector facade class."""

    def test_init_with_transformer_method(self):
        """Test that the transformer method constructs a TransformerDetector."""
        with patch("lettucedetect.models.inference.TransformerDetector") as mock_transformer:
            detector = HallucinationDetector(method="transformer", model_path="dummy_path")
            mock_transformer.assert_called_once_with(model_path="dummy_path")
            # The facade must hold the exact instance produced by the factory,
            # not merely *some* MagicMock (the previous isinstance check was
            # vacuous — any mock would have passed).
            assert detector.detector is mock_transformer.return_value

    def test_init_with_invalid_method(self):
        """Test that an unsupported method name raises ValueError."""
        with pytest.raises(ValueError):
            HallucinationDetector(method="invalid_method")

    def test_predict(self):
        """Test that predict delegates to the inner detector and returns its result."""
        # Create a mock detector with the predict method
        mock_detector = MagicMock()
        mock_detector.predict.return_value = []

        with patch(
            "lettucedetect.models.inference.TransformerDetector", return_value=mock_detector
        ):
            detector = HallucinationDetector(method="transformer")
            context = ["This is a test context."]
            answer = "This is a test answer."
            question = "What is the test?"

            result = detector.predict(context, answer, question)

            # The facade must pass the inner detector's return value through.
            assert result == []

            # Check that the mock detector's predict method was called with the correct arguments
            mock_detector.predict.assert_called_once()
            call_args = mock_detector.predict.call_args[0]
            assert call_args[0] == context
            assert call_args[1] == answer
            assert call_args[2] == question
            # "tokens" is the expected default output format.
            assert call_args[3] == "tokens"

    def test_predict_prompt(self):
        """Test that predict_prompt delegates to the inner detector and returns its result."""
        # Create a mock detector with the predict_prompt method
        mock_detector = MagicMock()
        mock_detector.predict_prompt.return_value = []

        with patch(
            "lettucedetect.models.inference.TransformerDetector", return_value=mock_detector
        ):
            detector = HallucinationDetector(method="transformer")
            prompt = "This is a test prompt."
            answer = "This is a test answer."

            result = detector.predict_prompt(prompt, answer)

            # The facade must pass the inner detector's return value through.
            assert result == []

            # Check that the mock detector's predict_prompt method was called with the correct arguments
            mock_detector.predict_prompt.assert_called_once()
            call_args = mock_detector.predict_prompt.call_args[0]
            assert call_args[0] == prompt
            assert call_args[1] == answer
            # "tokens" is the expected default output format.
            assert call_args[2] == "tokens"


class TestTransformerDetector:
    """Tests for the TransformerDetector class."""

    @pytest.fixture(autouse=True)
    def setup(self, mock_tokenizer, mock_model):
        """Patch the HF loader entry points so no real model is downloaded."""
        self.mock_tokenizer = mock_tokenizer
        self.mock_model = mock_model

        # Patch the AutoTokenizer and AutoModelForTokenClassification
        self.tokenizer_patcher = patch(
            "lettucedetect.models.inference.AutoTokenizer.from_pretrained",
            return_value=self.mock_tokenizer,
        )
        self.model_patcher = patch(
            "lettucedetect.models.inference.AutoModelForTokenClassification.from_pretrained",
            return_value=self.mock_model,
        )

        # Keep handles to the patched classmethods so tests can assert on calls.
        self.mock_tokenizer_cls = self.tokenizer_patcher.start()
        self.mock_model_cls = self.model_patcher.start()

        yield

        # Teardown: undo both patches after each test.
        self.tokenizer_patcher.stop()
        self.model_patcher.stop()

    def test_init(self):
        """Test that initialization loads tokenizer and model from the given path."""
        detector = TransformerDetector(model_path="dummy_path")

        self.mock_tokenizer_cls.assert_called_once_with("dummy_path")
        self.mock_model_cls.assert_called_once_with("dummy_path")
        assert detector.tokenizer == self.mock_tokenizer
        assert detector.model == self.mock_model
        assert detector.max_length == 4096

    def test_predict(self):
        """Test predict method.

        The previous version of this test built a MockEncoding class plus
        mock labels/offsets/answer-start fixtures that were never used; that
        dead code has been removed.
        """
        # Patch the _predict method to avoid the actual implementation
        with patch.object(TransformerDetector, "_predict", return_value=[]):
            detector = TransformerDetector(model_path="dummy_path")
            context = ["This is a test context."]
            answer = "This is a test answer."
            question = "What is the test?"

            result = detector.predict(context, answer, question)

            # Verify the result
            assert isinstance(result, list)

    def test_form_prompt_with_question(self):
        """Test _form_prompt method with a question (QA task)."""
        detector = TransformerDetector(model_path="dummy_path")
        context = ["This is passage 1.", "This is passage 2."]
        question = "What is the test?"

        prompt = detector._form_prompt(context, question)

        # Check that the prompt contains the question and numbered passages
        assert question in prompt
        assert "passage 1: This is passage 1." in prompt
        assert "passage 2: This is passage 2." in prompt

    def test_form_prompt_without_question(self):
        """Test _form_prompt method without a question (summary task)."""
        detector = TransformerDetector(model_path="dummy_path")
        context = ["This is a text to summarize."]

        prompt = detector._form_prompt(context, None)

        # Check that the prompt contains the text to summarize
        assert "This is a text to summarize." in prompt
        assert "Summarize" in prompt