Commit b56dee4e authored by Matt Hoffner's avatar Matt Hoffner
Browse files

add loglikelihood_rolling

parent 97d31082
import requests import requests
import logging import logging
import time import time
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM
from tqdm import tqdm from tqdm import tqdm
from requests.exceptions import RequestException from requests.exceptions import RequestException
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def ggml_completion(base_url, retries=3, delay=5, **kwargs): def ggml_completion(base_url, retries=3, delay=5, **kwargs):
...@@ -74,6 +75,20 @@ class GGMLLM(BaseLM): ...@@ -74,6 +75,20 @@ class GGMLLM(BaseLM):
res.append(None) # Add default value in case of error res.append(None) # Add default value in case of error
return reorderer.get_original(res) return reorderer.get_original(res)
def loglikelihood_rolling(self, requests):
results = []
for request in requests:
logprobs = []
for i in range(0, len(request), self.max_length):
chunk = request[i:i+self.max_length]
chunk_loglikelihood = self.loglikelihood([(chunk, request[i+1:i+self.max_length+1])])
logprobs.extend(chunk_loglikelihood)
avg_loglikelihood = sum([logprob for logprob, _ in logprobs]) / len(logprobs)
results.append((avg_loglikelihood, True))
return results
def _model_call(self, inps): def _model_call(self, inps):
...@@ -100,9 +115,8 @@ class GGMLLM(BaseLM): ...@@ -100,9 +115,8 @@ class GGMLLM(BaseLM):
raise NotImplementedError() raise NotImplementedError()
@property @property
def max_length(self): def max_length(self):
# Placeholder implementation return 1024
raise NotImplementedError()
@property @property
def max_gen_toks(self): def max_gen_toks(self):
......
...@@ -62,5 +62,18 @@ class GGMLLMTest(unittest.TestCase): ...@@ -62,5 +62,18 @@ class GGMLLMTest(unittest.TestCase):
expected_res = ["generated text until stop1", "generated text until stop1"] expected_res = ["generated text until stop1", "generated text until stop1"]
self.assertEqual(res, expected_res) self.assertEqual(res, expected_res)
@patch('lm_eval.models.ggml.ggml_completion', side_effect=ggml_completion_mock)
def test_loglikelihood_rolling(self, ggml_completion_mock):
lm = GGMLLM(base_url)
# Test loglikelihood_rolling
requests = ["input1", "input2"]
res = lm.loglikelihood_rolling(requests)
# Assert the loglikelihood_rolling response is correct
expected_res = [(-1.2345, True), (-1.2345, True)]
self.assertEqual(res, expected_res)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment