# test_ggml.py - unit tests for lm_eval.models.ggml.GGMLLM
import unittest
from unittest.mock import MagicMock
from lm_eval.models.ggml import GGMLLM


class GGMLLMTest(unittest.TestCase):
    """Tests for ``GGMLLM`` with its ``ggml_completion`` attribute stubbed.

    Each test builds a ``GGMLLM`` for ``BASE_URL`` and then overwrites the
    instance's ``ggml_completion`` attribute with a stub, so the values
    asserted below come from the stub's canned responses.
    """

    # Endpoint handed to the GGMLLM constructor in every test.
    BASE_URL = "https://matthoffner-ggml-llm-api.hf.space"

    def test_loglikelihood(self):
        """``loglikelihood`` yields one ``(logprob, is_greedy)`` per request."""
        lm = GGMLLM(self.BASE_URL)

        # Stub ggml_completion with a MagicMock that returns the same canned
        # response for every call, regardless of arguments.
        ggml_completion_mock = MagicMock()
        ggml_completion_mock.return_value = {
            "logprob": -1.2345,
            "is_greedy": True,
        }

        # Patch the stub onto the instance only; the class is left untouched.
        lm.ggml_completion = ggml_completion_mock

        # Test loglikelihood
        requests = [("context1", "continuation1"), ("context2", "continuation2")]
        res = lm.loglikelihood(requests)

        # Both requests should map to the stubbed logprob / is_greedy pair.
        expected_res = [(-1.2345, True), (-1.2345, True)]
        self.assertEqual(res, expected_res)

    def test_greedy_until(self):
        """``greedy_until`` returns the stubbed completion text per request."""
        lm = GGMLLM(self.BASE_URL)

        # Plain-function stub: appends the last element of ``stop`` to a fixed
        # text when ``stop`` is given, otherwise returns the fixed text.
        def ggml_completion_mock(url, context, stop=None):
            if stop is not None:
                return {"text": f"generated_text{stop[-1]}"}
            return {"text": "generated_text"}

        # Assigned on the instance, so it is called without ``self``.
        lm.ggml_completion = ggml_completion_mock

        # Test greedy_until
        requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
        res = lm.greedy_until(requests)

        # NOTE(review): these expectations hold when ``stop`` reaches the stub
        # as the raw ``until`` string (so ``stop[-1]`` is its final character)
        # - confirm against GGMLLM.greedy_until.
        expected_res = ["generated_text1", "generated_text2"]
        self.assertEqual(res, expected_res)

# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()