test_llama.py 1.81 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import unittest
from unittest.mock import MagicMock
from lm_eval.models.llama import LlamaLM

class LlamaLMTest(unittest.TestCase):
    """Unit tests for ``LlamaLM`` with its completion call stubbed out.

    Each test replaces ``llama_completion`` on the ``LlamaLM`` instance with a
    fake. The asserted results therefore come from the fake, not from the
    server at ``BASE_URL``. The tests rely on ``LlamaLM`` routing its requests
    through ``self.llama_completion``. Each test checks that the stub was
    actually invoked, so a bypass fails here instead of quietly contacting
    the network.
    """

    # Server URL handed to LlamaLM. It is not expected to be contacted while
    # llama_completion is stubbed. NOTE(review): confirm LlamaLM.__init__
    # makes no request of its own.
    BASE_URL = "https://matthoffner-ggml-llm-api.hf.space"

    def setUp(self):
        """Create a fresh ``LlamaLM`` for every test."""
        self.lm = LlamaLM(self.BASE_URL)

    def test_loglikelihood(self):
        """loglikelihood yields one ``(logprob, is_greedy)`` pair per request."""
        # The stub returns the same payload for every call.
        completion_mock = MagicMock(
            return_value={"logprob": -1.2345, "is_greedy": True}
        )
        # Patch the instance attribute so LlamaLM picks up the stub.
        self.lm.llama_completion = completion_mock

        reqs = [("context1", "continuation1"), ("context2", "continuation2")]
        res = self.lm.loglikelihood(reqs)

        # Fail with a clear message if LlamaLM bypassed the stub.
        completion_mock.assert_called()
        expected_res = [(-1.2345, True), (-1.2345, True)]
        self.assertEqual(res, expected_res)

    def test_greedy_until(self):
        """greedy_until yields one generated string per request."""

        def fake_completion(url, context, stop=None):
            # Append the last element of `stop` so the output reveals which
            # stop value was forwarded to the completion call.
            if stop is not None:
                return {"text": f"generated_text{stop[-1]}"}
            return {"text": "generated_text"}

        # Wrap the fake so we can verify it was called; the wrapper forwards
        # the call's args/kwargs to fake_completion unchanged.
        completion_mock = MagicMock(side_effect=fake_completion)
        self.lm.llama_completion = completion_mock

        reqs = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
        res = self.lm.greedy_until(reqs)

        completion_mock.assert_called()
        # NOTE(review): expected_res assumes `until` reaches the stub as the
        # plain string, so stop[-1] is its last character ("stop1" -> "1").
        expected_res = ["generated_text1", "generated_text2"]
        self.assertEqual(res, expected_res)




if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main(module="__main__")