Commit 0fa2c7df authored by Matt Hoffner

updates to verify end to end with example

parent b56dee4e
@@ -141,6 +141,15 @@ python main.py \
--tasks hellaswag
```
GGML quantized models can be loaded by using the `llama-cpp-python` server:
```bash
python main.py \
--model ggml \
--model_args base_url=http://localhost:8000 \
--tasks hellaswag
```
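The `base_url` above assumes a `llama-cpp-python` server is already running. As a rough sketch (the model path is a placeholder), such a server can be started with the package's built-in server module, which listens on `http://localhost:8000` by default:
```bash
pip install "llama-cpp-python[server]"
python -m llama_cpp.server --model ./models/7B/ggml-model-q4_0.bin
```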
We support wildcards in task names, for example, you can run all of the machine-translated lambada tasks via `--tasks lambada_openai_mt_*`.
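For example, combining the wildcard syntax with the GGML server setup above (reusing the placeholder `base_url`), an invocation might look like this; the pattern is quoted so the shell does not expand it:
```bash
python main.py \
    --model ggml \
    --model_args base_url=http://localhost:8000 \
    --tasks "lambada_openai_mt_*"
```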
We currently support only one prompt per task, which we strive to make the "standard" prompt as defined by the benchmark's authors. If you would like to study how varying the prompt changes the evaluation score, check out the [BigScience fork](https://github.com/bigscience-workshop/lm-evaluation-harness) of this repo. We are currently working on upstreaming this capability to `main`.

@@ -3,29 +3,41 @@ import logging
import time
from tqdm import tqdm
from requests.exceptions import RequestException
import transformers
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM
logger = logging.getLogger(__name__)
def ggml_completion(base_url, retries=3, delay=5, **kwargs):
    last_exception = None
    for _ in range(retries):
        try:
            response = requests.post(f"{base_url}/v1/completions", json=kwargs)
            response.raise_for_status()
            return response.json()
        except RequestException as e:
            last_exception = e  # remember the last error so it can be reported after the final retry
            logger.error(f"RequestException: {e}")
            time.sleep(delay)  # wait before retrying
    raise Exception(f"Failed to get a valid response after {retries} retries. Last exception: {last_exception}")
class GGMLLM(BaseLM):
    def __init__(self, base_url, truncate=False):
        super().__init__()
        self.base_url = base_url
        self.truncate = truncate
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        self.logprobs = 10
        self._max_length = 1024  # backing field for the max_length property below
        self.vocab_size = self.tokenizer.vocab_size
    def ggml_completion(self, base_url, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
        last_exception = None
        for _ in range(retries):
            try:
                prompt = context
                if continuation:
                    prompt += continuation
                request = {'prompt': prompt, 'logprobs': self.logprobs}
                if stop is not None:
                    request['stop'] = stop
                response = requests.post(f"{base_url}/v1/completions", json=request)
                response.raise_for_status()
                return response.json()
            except RequestException as e:
                last_exception = e  # remember the last error so it can be reported after the final retry
                logger.error(f"RequestException: {e}")
                time.sleep(delay)  # wait before retrying
        raise Exception(f"Failed to get a valid response after {retries} retries. Last exception: {last_exception}")

    def loglikelihood(self, requests):
        reorderer = Reorderer(requests, len)
@@ -33,20 +45,21 @@ class GGMLLM(BaseLM):
        res = []
        for context, continuation in tqdm(requests):
            response = self.ggml_completion(self.base_url, context=context, continuation=continuation)
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                logprobs = choice.get("logprobs")
                if logprobs and "token_logprobs" in logprobs:
                    logprob = logprobs["token_logprobs"][0]
                    is_greedy = choice["finish_reason"] == "length"
                    res.append((logprob, is_greedy))
                else:
                    logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.")
            else:
                logger.error(f"Invalid response for loglikelihood. Response: {response}")
                assert False
        return reorderer.get_original(res)
    def greedy_until(self, requests):
        if not requests:
@@ -60,8 +73,7 @@ class GGMLLM(BaseLM):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = self.ggml_completion(self.base_url, context=inp, stop=until)
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                if "text" in choice:
@@ -99,6 +111,12 @@ class GGMLLM(BaseLM):
        # Placeholder implementation
        raise NotImplementedError()

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    @property
    def batch_size(self):
        # Placeholder implementation
@@ -114,19 +132,10 @@ class GGMLLM(BaseLM):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def max_length(self):
        # return the backing field set in __init__ rather than the property itself, to avoid infinite recursion
        return self._max_length

    @property
    def max_gen_toks(self):
        # Placeholder implementation
        raise NotImplementedError()
@@ -38,7 +38,7 @@ def ggml_completion_mock(base_url, **kwargs):
class GGMLLMTest(unittest.TestCase):
    @patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
    def test_loglikelihood(self, ggml_completion_mock):
        lm = GGMLLM(base_url)
@@ -50,7 +50,7 @@ class GGMLLMTest(unittest.TestCase):
        expected_res = [(logprob, True) for logprob in [-1.2345, -1.2345]]
        self.assertEqual(res, expected_res)

    @patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
    def test_greedy_until(self, ggml_completion_mock):
        lm = GGMLLM(base_url)
@@ -62,7 +62,7 @@ class GGMLLMTest(unittest.TestCase):
        expected_res = ["generated text until stop1", "generated text until stop1"]
        self.assertEqual(res, expected_res)

    @patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
    def test_loglikelihood_rolling(self, ggml_completion_mock):
        lm = GGMLLM(base_url)