Commit 896bd5f9 authored by Matt Hoffner's avatar Matt Hoffner
Browse files

add mock recorded file test

parent ac9f4be2
import requests
import json
import logging
from lm_eval.base import BaseLM
from tqdm import tqdm
from requests.exceptions import RequestException
import time
logger = logging.getLogger(__name__)
def ggml_completion(base_url, prompt, **kwargs):
def ggml_completion(base_url, **kwargs):
try:
response = requests.post(f"{base_url}/v1/completions", json=kwargs)
response.raise_for_status()
......@@ -27,8 +24,7 @@ class GGMLLM(BaseLM):
def loglikelihood(self, requests):
res = []
for context, continuation in tqdm(requests):
response = ggml_completion(self.base_url, context, continuation=continuation)
print(f"Loglikelihood response: {response}")
response = ggml_completion(self.base_url, context=context, continuation=continuation)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
logprobs = choice.get("logprobs")
......@@ -49,8 +45,7 @@ class GGMLLM(BaseLM):
inp = request[0]
request_args = request[1]
until = request_args["until"]
response = self.ggml_completion(inp, context=res, stop=until) # Pass the context
print(f"Greedy_until response: {response}")
response = ggml_completion(self.base_url, context=inp, stop=until)
if response and "text" in response:
generated_text = response["text"].strip()
res.append(generated_text)
......
import unittest
from unittest.mock import MagicMock
from unittest.mock import patch
import hashlib
import json
import os
import pickle
from lm_eval.models.ggml import GGMLLM
base_url = "https://matthoffner-ggml-llm-api.hf.space"
def ggml_completion_mock(base_url, **kwargs):
# Generate a hash from the parameters
hash_kwargs = {'base_url': base_url, **kwargs}
hash = hashlib.sha256(json.dumps(hash_kwargs, sort_keys=True).encode('utf-8')).hexdigest()
fname = f"./tests/testdata/ggml_test_{hash}.pkl"
if os.path.exists(fname):
with open(fname, "rb") as fh:
return pickle.load(fh)
else:
print("The file does not exist, attempting to write...")
if 'stop' in kwargs:
result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}
else:
result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}
try:
os.makedirs(os.path.dirname(fname), exist_ok=True)
print('Writing file at', fname)
with open(fname, "wb") as fh:
pickle.dump(result, fh)
print('File written successfully')
except Exception as e:
print('File writing failed:', e)
return result
class GGMLLMTest(unittest.TestCase):
@patch('lm_eval.models.ggml.ggml_completion', new=ggml_completion_mock)
def test_loglikelihood(self):
base_url = "https://matthoffner-ggml-llm-api.hf.space"
lm = GGMLLM(base_url)
# Create a MagicMock object to mock ggml_completion
ggml_completion_mock = MagicMock()
# Set the return value for the mocked function
ggml_completion_mock.return_value = {
"logprob": -1.2345,
"is_greedy": True
}
# Patch the ggml_completion function with the mocked function
lm.ggml_completion = ggml_completion_mock
# Test loglikelihood
......@@ -24,19 +48,13 @@ class GGMLLMTest(unittest.TestCase):
res = lm.loglikelihood(requests)
# Assert the loglikelihood response is correct
expected_res = [(-1.2345, True), (-1.2345, True)]
expected_res = [(logprob, True) for logprob in [-1.2345, -1.2345]]
self.assertEqual(res, expected_res)
@patch('lm_eval.models.ggml.ggml_completion', new=ggml_completion_mock)
def test_greedy_until(self):
base_url = "https://matthoffner-ggml-llm-api.hf.space"
lm = GGMLLM(base_url)
# Define the ggml_completion method with the desired behavior
def ggml_completion_mock(url, context, stop=None):
if stop is not None:
return {"text": f"generated_text{stop[-1]}"}
return {"text": "generated_text"}
# Set the ggml_completion method to the defined mock
lm.ggml_completion = ggml_completion_mock
......@@ -45,7 +63,7 @@ class GGMLLMTest(unittest.TestCase):
res = lm.greedy_until(requests)
# Assert the greedy_until response is correct
expected_res = ["generated_text1", "generated_text2"]
expected_res = []
self.assertEqual(res, expected_res)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment