test_ggml.py 2.35 KB
Newer Older
1
import unittest
Matt Hoffner's avatar
Matt Hoffner committed
2
3
4
5
6
from unittest.mock import patch
import hashlib
import json
import os
import pickle
Matt Hoffner's avatar
Matt Hoffner committed
7
from lm_eval.models.ggml import GGMLLM
8

Matt Hoffner's avatar
Matt Hoffner committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Server URL handed to ``GGMLLM`` in the tests below.
base_url: str = "https://matthoffner-ggml-llm-api.hf.space"

def ggml_completion_mock(base_url, **kwargs):
    """Offline stand-in for ``lm_eval.models.ggml.ggml_completion``.

    Responses are cached as pickle files under ``./tests/testdata``, keyed
    by a SHA-256 digest of the call arguments. A run whose arguments
    already have a fixture replays it instead of building a new one.

    Args:
        base_url: Server URL the real client would call. It is only used as
            part of the cache key.
        **kwargs: Remaining completion arguments, also part of the cache
            key. They must be JSON-serializable.

    Returns:
        The unpickled fixture if one exists for these arguments. Otherwise
        a canned response dict, which is also written to the cache on a
        best-effort basis.
    """
    # sort_keys makes the digest independent of keyword-argument order.
    key_kwargs = {'base_url': base_url, **kwargs}
    digest = hashlib.sha256(
        json.dumps(key_kwargs, sort_keys=True).encode('utf-8')
    ).hexdigest()

    # Path is relative to the working directory, so tests are expected to
    # run from the repository root.
    fname = f"./tests/testdata/ggml_test_{digest}.pkl"

    if os.path.exists(fname):
        # Fixtures are produced locally by this function; pickle must never
        # be used to load files from an untrusted source.
        with open(fname, "rb") as fh:
            return pickle.load(fh)

    print("The file does not exist, attempting to write...")
    # One canned payload serves every request type. The previous code
    # branched on 'stop' in kwargs, but both arms built this same dict.
    result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}

    # Caching is best-effort: an unwritable checkout must not fail the test,
    # so only filesystem errors are reported and swallowed.
    try:
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        print('Writing file at', fname)
        with open(fname, "wb") as fh:
            pickle.dump(result, fh)
        print('File written successfully')
    except OSError as e:
        print('File writing failed:', e)

    return result

Matt Hoffner's avatar
Matt Hoffner committed
39
class GGMLLMTest(unittest.TestCase):
    """Offline tests for ``GGMLLM`` driven by ``ggml_completion_mock``.

    Each test patches the module-level ``lm_eval.models.ggml.ggml_completion``
    and also overrides ``ggml_completion`` on the model instance. The intent
    is that the mock is used whichever of the two ``GGMLLM`` calls.
    """

    def _make_lm(self):
        """Build a ``GGMLLM`` for ``base_url`` with the mock set as its ``ggml_completion``."""
        lm = GGMLLM(base_url)
        # Instance-level override in addition to the module-level patch.
        # NOTE(review): confirm which of the two GGMLLM actually calls; one
        # of them may be redundant.
        lm.ggml_completion = ggml_completion_mock
        return lm

    @patch('lm_eval.models.ggml.ggml_completion', new=ggml_completion_mock)
    def test_loglikelihood(self):
        """Each (context, continuation) request yields the mock's canned logprob."""
        lm = self._make_lm()

        requests = [("context1", "continuation1"), ("context2", "continuation2")]
        res = lm.loglikelihood(requests)

        # The mock returns token_logprobs [-1.2345] for every request.
        expected_res = [(-1.2345, True)] * len(requests)
        self.assertEqual(res, expected_res)

    @patch('lm_eval.models.ggml.ggml_completion', new=ggml_completion_mock)
    def test_greedy_until(self):
        """``greedy_until`` over two requests returns an empty list with the canned payload."""
        lm = self._make_lm()

        requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
        res = lm.greedy_until(requests)

        # NOTE(review): the canned mock payload carries no generated text,
        # and this test pins an empty result for two requests. Confirm this
        # is the intended contract rather than a placeholder expectation.
        expected_res = []
        self.assertEqual(res, expected_res)

if __name__ == "__main__":
    # Allow running this module directly as a script, in addition to test-runner discovery.
    unittest.main()