Commit 85d3b8d1 authored by Matt Hoffner

remove reordering to get cli running

parent 0fa2c7df
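For context: Reorderer is the length-sorting helper from lm_eval.utils. It reorders requests by a key function (here len) so similar-length prompts are processed together, then maps the results back to the caller's original order. A minimal sketch of that round trip, assuming the lm_eval.utils.Reorderer API (get_reordered / get_original); the sample data is hypothetical:

```python
from lm_eval.utils import Reorderer

requests = ["bb", "a", "ccc"]             # hypothetical inputs
reorderer = Reorderer(requests, len)      # sort key: prompt length
reordered = reorderer.get_reordered()     # ["a", "bb", "ccc"]
results = [r.upper() for r in reordered]  # stand-in for the per-request model calls
print(reorderer.get_original(results))    # ["BB", "A", "CCC"], caller's order restored
```

Removing it turns loglikelihood and greedy_until into plain iteration over requests, which is what the diff below does.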
@@ -20,7 +20,7 @@ class GGMLLM(BaseLM):
         self.max_length = 1024
         self.vocab_size = self.tokenizer.vocab_size
-    def ggml_completion(self, base_url, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
+    def ggml_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
         for _ in range(retries):
             try:
                 prompt = context
@@ -29,7 +29,7 @@ class GGMLLM(BaseLM):
                 request = {'prompt': prompt, 'logprobs': self.logprobs}
                 if stop is not None:
                     request['stop'] = stop
-                response = requests.post(f"{base_url}/v1/completions", json=request)
+                response = requests.post(f"{self.base_url}/v1/completions", json=request)
                 response.raise_for_status()
                 return response.json()
             except RequestException as e:
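The request this method builds is easy to reproduce outside the class. A standalone sketch, assuming a llama.cpp-style server listening at a hypothetical local URL; the prompt and logprobs values are made up:

```python
import requests

base_url = "http://localhost:8000"  # assumption: local ggml server speaking this OpenAI-style API
payload = {"prompt": "The capital of France is", "logprobs": 10}
payload["stop"] = ["\n"]            # optional, exactly as the diff adds it
response = requests.post(f"{base_url}/v1/completions", json=payload)
response.raise_for_status()
print(response.json())
```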
@@ -40,9 +40,8 @@ class GGMLLM(BaseLM):
     def loglikelihood(self, requests):
-        reorderer = Reorderer(requests, len)
-        requests = reorderer.get_reordered()
+        if not requests:
+            return []
         res = []
         for context, continuation in tqdm(requests):
             response = self.ggml_completion(self.base_url, context=context, continuation=continuation)
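Note that this call still passes self.base_url as the first positional argument. Under the new signature that slot is context, so the keyword context=context collides with it and Python raises a TypeError at call time. A self-contained demo with a hypothetical stand-in class (not the repo's code):

```python
class Demo:
    # stand-in with the new ggml_completion signature
    def ggml_completion(self, context, continuation=None, stop=None, **kwargs):
        return context

try:
    Demo().ggml_completion("http://localhost:8000", context="some prompt")
except TypeError as e:
    print(e)  # ... got multiple values for argument 'context'
```

The call would need to drop the leading argument, i.e. self.ggml_completion(context=context, continuation=continuation).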
@@ -58,16 +57,13 @@ class GGMLLM(BaseLM):
             else:
                 logger.error(f"Invalid response for loglikelihood. Response: {response}")
                 assert False
-        return reorderer.get_original(res)
+        return res
     def greedy_until(self, requests):
         if not requests:
             return []
-        reorderer = Reorderer(requests, len)
-        requests = reorderer.get_reordered()
         res = []
         for request in tqdm(requests):
             inp = request[0]
@@ -85,7 +81,7 @@ class GGMLLM(BaseLM):
             else:
                 logger.error(f"Invalid response for greedy_until. Response: {response}")
                 res.append(None)  # Add default value in case of error
-        return reorderer.get_original(res)
+        return res
     def loglikelihood_rolling(self, requests):
         results = []
@@ -59,7 +59,7 @@ class GGMLLMTest(unittest.TestCase):
         res = lm.greedy_until(requests)
         # Assert the greedy_until response is correct
-        expected_res = ["generated text until stop1", "generated text until stop1"]
+        expected_res = ["generated text until stop1", "generated text until stop2"]
         self.assertEqual(res, expected_res)
     @patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
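For reference, the side_effect pattern the test relies on: patch swaps GGMLLM.ggml_completion for a mock that returns canned responses. A minimal sketch; only the patch target comes from the diff, while the mock body and the response shape are assumptions:

```python
from unittest.mock import patch

def ggml_completion_mock(context, continuation=None, stop=None, **kwargs):
    # canned reply; the {"choices": [{"text": ...}]} shape is an assumed response format
    return {"choices": [{"text": f"generated text until {stop}"}]}

with patch('lm_eval.models.ggml.GGMLLM.ggml_completion',
           side_effect=ggml_completion_mock):
    ...  # calls to lm.greedy_until(requests) inside this block hit the mock
```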