Commit 97d31082 authored by Matt Hoffner

updates from feedback

parent 896bd5f9
import requests
import logging
import time
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM
from tqdm import tqdm
from requests.exceptions import RequestException
logger = logging.getLogger(__name__)
def ggml_completion(base_url, retries=3, delay=5, **kwargs):
    last_exc = None
    for _ in range(retries):
        try:
            response = requests.post(f"{base_url}/v1/completions", json=kwargs)
            response.raise_for_status()
            return response.json()
        except RequestException as e:
            # keep a reference: Python unbinds `e` after the except block,
            # so it cannot be referenced once the loop finishes
            last_exc = e
            logger.error(f"RequestException: {e}")
            time.sleep(delay)  # wait before retrying
    raise Exception(f"Failed to get a valid response after {retries} retries. Last exception: {last_exc}")
class GGMLLM(BaseLM):
    def __init__(self, base_url, truncate=False):
@@ -22,37 +27,54 @@ class GGMLLM(BaseLM):
        self.truncate = truncate
    def loglikelihood(self, requests):
        reorderer = Reorderer(requests, len)
        requests = reorderer.get_reordered()
        res = []
        for context, continuation in tqdm(requests):
            response = ggml_completion(self.base_url, context=context, continuation=continuation)
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                logprobs = choice.get("logprobs")
                try:
                    logprob = logprobs["token_logprobs"][0]
                except (TypeError, KeyError, IndexError):
                    raise ValueError("Invalid logprobs data. Expected 'logprobs' to contain a non-empty 'token_logprobs' list.")
                is_greedy = choice["finish_reason"] == "length"
                res.append((logprob, is_greedy))
            else:
                logger.error(f"Invalid response for loglikelihood. Response: {response}")
                assert False
        return reorderer.get_original(res)
    def greedy_until(self, requests):
        if not requests:
            return []
        reorderer = Reorderer(requests, len)
        requests = reorderer.get_reordered()
        res = []
        for request in tqdm(requests):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = ggml_completion(self.base_url, context=inp, stop=until)
            logger.debug(f"greedy_until response: {response}")
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                if "text" in choice:
                    generated_text = choice["text"].strip()
                    res.append(generated_text)
                else:
                    logger.error(f"Invalid response for greedy_until. Response: {response}")
                    res.append(None)  # add a default value in case of error
            else:
                logger.error(f"Invalid response for greedy_until. Response: {response}")
                res.append(None)  # add a default value in case of error
        return reorderer.get_original(res)
    def _model_call(self, inps):
        # Placeholder implementation
......
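# Not part of this commit: a usage sketch exercising the class above.
# The server URL, prompts, and stop sequence are illustrative placeholders.
def _example_ggmllm_usage():
    lm = GGMLLM("http://localhost:8000", truncate=False)  # assumed server URL
    # loglikelihood returns (logprob, is_greedy) pairs in request order
    scores = lm.loglikelihood([("The capital of France is", " Paris")])
    # greedy_until generates text until the stop string is produced
    completions = lm.greedy_until([("Q: 2+2=\nA:", {"until": "\n"})])
    return scores, completions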
@@ -21,7 +21,7 @@ def ggml_completion_mock(base_url, **kwargs):
    else:
        print("The file does not exist, attempting to write...")
    if 'stop' in kwargs:
        result = {"choices": [{"text": f"generated text until {kwargs['stop']}", "logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}
    else:
        result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}
@@ -36,13 +36,12 @@ def ggml_completion_mock(base_url, **kwargs):
    return result
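# Note, not part of this commit: with @patch(..., side_effect=fn) the target
# is replaced by a MagicMock that delegates to fn and is injected into the
# test method, so call counts and arguments can be asserted on it
# (e.g. ggml_completion_mock.assert_called()); @patch(..., new=fn) swaps fn
# in directly and injects no extra argument.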
class GGMLLMTest(unittest.TestCase):
    @patch('lm_eval.models.ggml.ggml_completion', side_effect=ggml_completion_mock)
    def test_loglikelihood(self, ggml_completion_mock):
        lm = GGMLLM(base_url)
        # Test loglikelihood
        requests = [("context1", "continuation1"), ("context2", "continuation2")]
        res = lm.loglikelihood(requests)
@@ -51,19 +50,16 @@ class GGMLLMTest(unittest.TestCase):
        expected_res = [(logprob, True) for logprob in [-1.2345, -1.2345]]
        self.assertEqual(res, expected_res)
    @patch('lm_eval.models.ggml.ggml_completion', side_effect=ggml_completion_mock)
    def test_greedy_until(self, ggml_completion_mock):
        lm = GGMLLM(base_url)
        # Test greedy_until
        requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
        res = lm.greedy_until(requests)
        # Assert the greedy_until response is correct: the mock echoes each
        # request's stop string, so the two results differ
        expected_res = ["generated text until stop1", "generated text until stop2"]
        self.assertEqual(res, expected_res)
if __name__ == "__main__":
......
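# Not part of this commit: assuming this test module is saved as
# tests/test_ggml.py in the checkout, it can be run directly with
#
#     python -m unittest tests.test_ggml -v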