Commit 97d31082 authored by Matt Hoffner's avatar Matt Hoffner
Browse files

updates from feedback

parent 896bd5f9
import requests import requests
import logging import logging
import time
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM from lm_eval.base import BaseLM
from tqdm import tqdm from tqdm import tqdm
from requests.exceptions import RequestException from requests.exceptions import RequestException
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def ggml_completion(base_url, **kwargs): def ggml_completion(base_url, retries=3, delay=5, **kwargs):
try: for _ in range(retries):
response = requests.post(f"{base_url}/v1/completions", json=kwargs) try:
response.raise_for_status() response = requests.post(f"{base_url}/v1/completions", json=kwargs)
return response.json() response.raise_for_status()
except RequestException as e: return response.json()
print(f"RequestException: {e}") except RequestException as e:
return None logger.error(f"RequestException: {e}")
time.sleep(delay) # wait before retrying
else:
raise Exception(f"Failed to get a valid response after {retries} retries. Last exception: {e}")
class GGMLLM(BaseLM): class GGMLLM(BaseLM):
def __init__(self, base_url, truncate=False): def __init__(self, base_url, truncate=False):
...@@ -22,37 +27,54 @@ class GGMLLM(BaseLM): ...@@ -22,37 +27,54 @@ class GGMLLM(BaseLM):
self.truncate = truncate self.truncate = truncate
def loglikelihood(self, requests): def loglikelihood(self, requests):
reorderer = Reorderer(requests, len)
requests = reorderer.get_reordered()
res = [] res = []
for context, continuation in tqdm(requests): for context, continuation in tqdm(requests):
response = ggml_completion(self.base_url, context=context, continuation=continuation) response = ggml_completion(self.base_url, context=context, continuation=continuation)
if response and "choices" in response and response["choices"]: if response and "choices" in response and response["choices"]:
choice = response["choices"][0] choice = response["choices"][0]
logprobs = choice.get("logprobs") logprobs = choice.get("logprobs")
logprob = logprobs["token_logprobs"][0] if logprobs and logprobs["token_logprobs"] else -1.2345 try:
logprob = logprobs["token_logprobs"][0]
except TypeError:
raise ValueError("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.")
is_greedy = choice["finish_reason"] == "length" is_greedy = choice["finish_reason"] == "length"
res.append((logprob, is_greedy)) res.append((logprob, is_greedy))
else: else:
logger.error(f"Invalid response for loglikelihood. Response: {response}") logger.error(f"Invalid response for loglikelihood. Response: {response}")
assert False assert False
return res return reorderer.get_original(res)
def greedy_until(self, requests): def greedy_until(self, requests):
if not requests: if not requests:
return [] return []
reorderer = Reorderer(requests, len)
requests = reorderer.get_reordered()
res = [] res = []
for request in tqdm(requests): for request in tqdm(requests):
inp = request[0] inp = request[0]
request_args = request[1] request_args = request[1]
until = request_args["until"] until = request_args["until"]
response = ggml_completion(self.base_url, context=inp, stop=until) response = ggml_completion(self.base_url, context=inp, stop=until)
if response and "text" in response: print(response);
generated_text = response["text"].strip() if response and "choices" in response and response["choices"]:
res.append(generated_text) choice = response["choices"][0]
if "text" in choice:
generated_text = choice["text"].strip()
res.append(generated_text)
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
else: else:
logger.error(f"Invalid response for greedy_until. Response: {response}") logger.error(f"Invalid response for greedy_until. Response: {response}")
continue res.append(None) # Add default value in case of error
return res return reorderer.get_original(res)
def _model_call(self, inps): def _model_call(self, inps):
# Placeholder implementation # Placeholder implementation
......
...@@ -21,7 +21,7 @@ def ggml_completion_mock(base_url, **kwargs): ...@@ -21,7 +21,7 @@ def ggml_completion_mock(base_url, **kwargs):
else: else:
print("The file does not exist, attempting to write...") print("The file does not exist, attempting to write...")
if 'stop' in kwargs: if 'stop' in kwargs:
result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]} result = {"choices": [{"text": f"generated text until {kwargs['stop']}", "logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}
else: else:
result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]} result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}
...@@ -36,13 +36,12 @@ def ggml_completion_mock(base_url, **kwargs): ...@@ -36,13 +36,12 @@ def ggml_completion_mock(base_url, **kwargs):
return result return result
class GGMLLMTest(unittest.TestCase): class GGMLLMTest(unittest.TestCase):
@patch('lm_eval.models.ggml.ggml_completion', new=ggml_completion_mock) @patch('lm_eval.models.ggml.ggml_completion', side_effect=ggml_completion_mock)
def test_loglikelihood(self): def test_loglikelihood(self, ggml_completion_mock):
lm = GGMLLM(base_url) lm = GGMLLM(base_url)
lm.ggml_completion = ggml_completion_mock
# Test loglikelihood # Test loglikelihood
requests = [("context1", "continuation1"), ("context2", "continuation2")] requests = [("context1", "continuation1"), ("context2", "continuation2")]
res = lm.loglikelihood(requests) res = lm.loglikelihood(requests)
...@@ -51,19 +50,16 @@ class GGMLLMTest(unittest.TestCase): ...@@ -51,19 +50,16 @@ class GGMLLMTest(unittest.TestCase):
expected_res = [(logprob, True) for logprob in [-1.2345, -1.2345]] expected_res = [(logprob, True) for logprob in [-1.2345, -1.2345]]
self.assertEqual(res, expected_res) self.assertEqual(res, expected_res)
@patch('lm_eval.models.ggml.ggml_completion', new=ggml_completion_mock) @patch('lm_eval.models.ggml.ggml_completion', side_effect=ggml_completion_mock)
def test_greedy_until(self): def test_greedy_until(self, ggml_completion_mock):
lm = GGMLLM(base_url) lm = GGMLLM(base_url)
# Set the ggml_completion method to the defined mock
lm.ggml_completion = ggml_completion_mock
# Test greedy_until # Test greedy_until
requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})] requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
res = lm.greedy_until(requests) res = lm.greedy_until(requests)
# Assert the greedy_until response is correct # Assert the greedy_until response is correct
expected_res = [] expected_res = ["generated text until stop1", "generated text until stop1"]
self.assertEqual(res, expected_res) self.assertEqual(res, expected_res)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment