Commit 85d3b8d1 authored by Matt Hoffner
Browse files

remove reordering to get cli running

parent 0fa2c7df
...@@ -20,7 +20,7 @@ class GGMLLM(BaseLM): ...@@ -20,7 +20,7 @@ class GGMLLM(BaseLM):
self.max_length = 1024 self.max_length = 1024
self.vocab_size = self.tokenizer.vocab_size self.vocab_size = self.tokenizer.vocab_size
def ggml_completion(self, base_url, context, continuation=None, stop=None, retries=3, delay=5, **kwargs): def ggml_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
for _ in range(retries): for _ in range(retries):
try: try:
prompt = context prompt = context
...@@ -29,7 +29,7 @@ class GGMLLM(BaseLM): ...@@ -29,7 +29,7 @@ class GGMLLM(BaseLM):
request = {'prompt': prompt, 'logprobs': self.logpobs} request = {'prompt': prompt, 'logprobs': self.logpobs}
if stop is not None: if stop is not None:
request['stop'] = stop request['stop'] = stop
response = requests.post(f"{base_url}/v1/completions", json=request) response = requests.post(f"{self.base_url}/v1/completions", json=request)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
except RequestException as e: except RequestException as e:
...@@ -40,9 +40,8 @@ class GGMLLM(BaseLM): ...@@ -40,9 +40,8 @@ class GGMLLM(BaseLM):
def loglikelihood(self, requests): def loglikelihood(self, requests):
reorderer = Reorderer(requests, len) if not requests:
requests = reorderer.get_reordered() return []
res = [] res = []
for context, continuation in tqdm(requests): for context, continuation in tqdm(requests):
response = self.ggml_completion(self.base_url, context=context, continuation=continuation) response = self.ggml_completion(self.base_url, context=context, continuation=continuation)
...@@ -58,16 +57,13 @@ class GGMLLM(BaseLM): ...@@ -58,16 +57,13 @@ class GGMLLM(BaseLM):
else: else:
logger.error(f"Invalid response for loglikelihood. Response: {response}") logger.error(f"Invalid response for loglikelihood. Response: {response}")
assert False assert False
return reorderer.get_original(res) return res
def greedy_until(self, requests): def greedy_until(self, requests):
if not requests: if not requests:
return [] return []
reorderer = Reorderer(requests, len)
requests = reorderer.get_reordered()
res = [] res = []
for request in tqdm(requests): for request in tqdm(requests):
inp = request[0] inp = request[0]
...@@ -85,7 +81,7 @@ class GGMLLM(BaseLM): ...@@ -85,7 +81,7 @@ class GGMLLM(BaseLM):
else: else:
logger.error(f"Invalid response for greedy_until. Response: {response}") logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error res.append(None) # Add default value in case of error
return reorderer.get_original(res) return res
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests):
results = [] results = []
......
...@@ -59,7 +59,7 @@ class GGMLLMTest(unittest.TestCase): ...@@ -59,7 +59,7 @@ class GGMLLMTest(unittest.TestCase):
res = lm.greedy_until(requests) res = lm.greedy_until(requests)
# Assert the greedy_until response is correct # Assert the greedy_until response is correct
expected_res = ["generated text until stop1", "generated text until stop1"] expected_res = ["generated text until stop1", "generated text until stop2"]
self.assertEqual(res, expected_res) self.assertEqual(res, expected_res)
@patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock) @patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment