test_ggml.py 2.84 KB
Newer Older
1
import unittest
Matt Hoffner's avatar
Matt Hoffner committed
2
3
4
5
6
from unittest.mock import patch
import hashlib
import json
import os
import pickle
Matt Hoffner's avatar
Matt Hoffner committed
7
from lm_eval.models.ggml import GGMLLM
8

Matt Hoffner's avatar
Matt Hoffner committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Base URL handed to ``GGMLLM`` in every test below. Each test patches
# ``GGMLLM.ggml_completion``, so presumably this host is never contacted
# (TODO confirm the ``GGMLLM`` constructor performs no network I/O).
base_url = "https://matthoffner-ggml-llm-api.hf.space"

def ggml_completion_mock(base_url, **kwargs):
    """Stand-in for ``GGMLLM.ggml_completion`` backed by on-disk fixtures.

    The request (``base_url`` plus every keyword argument) is serialised to
    key-sorted JSON and hashed with SHA-256; the hex digest names a pickle
    file under ``./tests/testdata``. If that file can be opened, its contents
    are returned. Otherwise a canned response is built, written to that file
    on a best-effort basis, and returned.

    Args:
        base_url: URL included in the fixture key; it is not contacted.
        **kwargs: Completion parameters. They must be JSON-serialisable
            because they are part of the fixture key. When a ``stop`` key is
            present the canned response carries a ``text`` field echoing it.

    Returns:
        dict: ``{"choices": [choice]}`` where ``choice`` always holds
        ``logprobs`` (``{"token_logprobs": [-1.2345]}``) and
        ``finish_reason`` (``"length"``), plus ``text`` when ``stop`` was
        supplied. A previously cached fixture is returned as stored.

    Raises:
        TypeError: If ``base_url`` or a keyword argument cannot be encoded
            as JSON.
    """
    request_key = json.dumps({'base_url': base_url, **kwargs}, sort_keys=True)
    # Named ``digest`` rather than ``hash`` so the builtin is not shadowed.
    digest = hashlib.sha256(request_key.encode('utf-8')).hexdigest()

    fname = f"./tests/testdata/ggml_test_{digest}.pkl"

    # NOTE: unpickling executes arbitrary code. This is acceptable only
    # because the fixtures are files this test suite wrote itself; never
    # point this at data from an untrusted source.
    try:
        cached = open(fname, "rb")
    except FileNotFoundError:
        print("The file does not exist, attempting to write...")
    else:
        with cached:
            return pickle.load(cached)

    choice = {"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}
    if 'stop' in kwargs:
        # Echo the stop sequence so each request yields a distinguishable text.
        choice = {"text": f"generated text until {kwargs['stop']}", **choice}
    result = {"choices": [choice]}

    # Persisting the fixture is best-effort: a read-only checkout must not
    # make the tests fail, so I/O problems are reported and then ignored.
    try:
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        print('Writing file at', fname)
        with open(fname, "wb") as fh:
            pickle.dump(result, fh)
        print('File written successfully')
    except (OSError, pickle.PickleError) as e:
        print('File writing failed:', e)

    return result

Matt Hoffner's avatar
Matt Hoffner committed
39

Matt Hoffner's avatar
Matt Hoffner committed
40
class GGMLLMTest(unittest.TestCase):
    """Exercise the ``GGMLLM`` request methods against a mocked completion call.

    Every test replaces ``GGMLLM.ggml_completion`` with a mock whose
    ``side_effect`` is the module-level ``ggml_completion_mock``; the mock
    object is injected by ``patch`` as the test method's extra argument.
    """

    @patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
    def test_loglikelihood(self, ggml_completion_mock):
        """``loglikelihood`` yields one ``(logprob, True)`` pair per request."""
        lm = GGMLLM(base_url)

        # Two (context, continuation) pairs.
        requests = [("context1", "continuation1"), ("context2", "continuation2")]
        res = lm.loglikelihood(requests)

        # -1.2345 is the only token logprob the canned response contains.
        expected_res = [(-1.2345, True), (-1.2345, True)]
        self.assertEqual(res, expected_res)

    @patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
    def test_greedy_until(self, ggml_completion_mock):
        """``greedy_until`` returns the generated text for each request."""
        lm = GGMLLM(base_url)

        # Each request carries its own stop sequence under "until".
        requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
        res = lm.greedy_until(requests)

        # The canned response echoes the ``stop`` keyword it receives, so the
        # second entry must reflect "stop2" rather than repeat "stop1".
        # NOTE(review): assumes GGMLLM forwards each request's "until" value
        # as ``stop`` -- confirm against lm_eval.models.ggml.
        expected_res = ["generated text until stop1", "generated text until stop2"]
        self.assertEqual(res, expected_res)

    @patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
    def test_loglikelihood_rolling(self, ggml_completion_mock):
        """``loglikelihood_rolling`` yields one ``(logprob, True)`` pair per input."""
        lm = GGMLLM(base_url)

        # Plain input strings, no continuation.
        requests = ["input1", "input2"]
        res = lm.loglikelihood_rolling(requests)

        expected_res = [(-1.2345, True), (-1.2345, True)]
        self.assertEqual(res, expected_res)


78
79
# Allow running this file directly as a script, not only through a test runner.
if __name__ == "__main__":
    unittest.main()