Commit 91868979 authored by baberabb's avatar baberabb
Browse files

Merge remote-tracking branch 'origin/big-refactor' into big-refactor_vllm

# Conflicts:
#	lm_eval/models/__init__.py
parents 0635af13 815cd29d
......@@ -146,7 +146,7 @@ A full accounting of the supported and planned libraries + APIs can be seen belo
| GooseAI | :heavy_check_mark: (not separately maintained) | `openai`, `openai-completions`, `gooseai` (same interface as OpenAI Completions) | | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Textsynth | Needs testing | `textsynth` | ??? | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Cohere | :hourglass: - blocked on Cohere API bug | N/A | [All `cohere.generate()` engines](https://docs.cohere.com/docs/models) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| GGML | :hourglass: [PR](https://github.com/EleutherAI/lm-evaluation-harness/pull/617) | N/A | ??? | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| GGML/[Llama.cpp](https://github.com/ggerganov/llama.cpp) (via [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)) | :heavy_check_mark: | `gguf`, `ggml` | Llama-architecture models (Llama, Llama 2, Llemma, Mistral(?), Llama finetunes) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| vLLM | :x: Not yet - needs help! | N/A | All HF models | `generate_until` (no logprobs) |
| Your inference server here! | ... | ... | ... | ... | | ... |
......
......@@ -273,6 +273,24 @@ to the top of any Python file that is run or imported when performing evaluation
Passing `--tasks /path/to/yaml/file` is also accepted.
## Beautifying Table Display
To avoid conflict, each task needs to be registered with a unique name. Because of this, slight variations of task are still counted as unique tasks and need to be named uniquely. This could be done by appending an additional naming that may refer to the variation such as in MMLU where the template used to evaluated for flan are differentiated from the default by the prefix `mmlu_flan_*`. Printing the full task names can easily clutter the results table at the end of the evaluation especially when you have a long list of tasks or are using a benchmark that comprises of many tasks. To make it more legible, you can use `task_alias` and `group_alias` to provide an alternative task name and group name that will be printed.
``
for example in `mmlu_abstract_algebra.yaml` we set `group_alias` to `stem` and `task_alias` to `abstract_algebra`.
```
"dataset_name": "abstract_algebra"
"description": "The following are multiple choice questions (with answers) about abstract\
\ algebra.\n\n"
"group": "mmlu_stem"
"group_alias": "stem"
"include": "_default_template_yaml"
"task": "mmlu_abstract_algebra"
"task_alias": "abstract_algebra"
```
Note: Even though `group` can be a list, for now, `group_alias` can only be a single string.
## Checking validity
After registering your task, you can now check on your data downloading and verify that the few-shot samples look as intended. Run the following command with your desired args:
......
......@@ -3,6 +3,7 @@ from . import openai_completions
from . import textsynth
from . import dummy
from . import anthropic_llms
from . import gguf
from . import vllm_causallms
......
import requests
import logging
import time
from tqdm import tqdm
from requests.exceptions import RequestException
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
logger = logging.getLogger(__name__)
def get_result(logprobs, context_length):
is_greedy = True
offsets = logprobs["text_offset"]
tokens = logprobs["tokens"]
tokens_logprobs = logprobs["token_logprobs"]
idx = 0
while offsets[idx] < context_length:
idx += 1
continuation_logprobs = sum(tokens_logprobs[idx:-1])
for i in range(idx, len(tokens)):
token = tokens[i]
top_tokens = logprobs["top_logprobs"][i]
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
if top_token != token:
is_greedy = False
break
return continuation_logprobs, is_greedy
@register_model("gguf", "ggml")
class GGUFLM(LM):
def __init__(self, base_url=None, max_length=2048, **kwargs):
super().__init__()
self.base_url = base_url
assert self.base_url, "must pass `base_url` to use GGUF LM!"
self.logprobs = 10
self.temperature = 0.0
self.max_length = max_length
def gguf_completion(
self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs
):
for _ in range(retries):
try:
prompt = context
request = {
"prompt": prompt,
"logprobs": self.logprobs,
"temperature": self.temperature,
}
if continuation:
prompt += continuation
request.update({"prompt": prompt, "max_tokens": 1, "echo": True})
if stop is not None:
request["stop"] = stop
response = requests.post(
f"{self.base_url}/v1/completions", json=request
)
response.raise_for_status()
return response.json()
except RequestException as e:
logger.error(f"RequestException: {e}")
time.sleep(delay) # wait before retrying
else:
raise Exception(f"Failed to get a valid response after {retries} retries.")
def loglikelihood(self, requests):
if not requests:
return []
res = []
for context, continuation in tqdm([req.args for req in requests]):
response = self.gguf_completion(context=context, continuation=continuation)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
logprobs = choice.get("logprobs")
if (
logprobs
and "token_logprobs" in logprobs
and logprobs["token_logprobs"]
):
logprob, is_greedy = get_result(logprobs, len(context))
res.append((logprob, is_greedy))
else:
logger.warning(
"Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
)
else:
logger.error(
f"Invalid response for loglikelihood. Response: {response}"
)
assert False
return res
def generate_until(self, requests):
if not requests:
return []
res = []
for request in tqdm([req.args for req in requests]):
inp = request[0]
request_args = request[1]
until = request_args.get("until", ["</s>"])
response = self.gguf_completion(context=inp, stop=until)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
if "text" in choice:
generated_text = choice["text"].strip()
res.append(generated_text)
else:
logger.error(
f"Invalid response for greedy_until. Response: {response}"
)
res.append(None) # Add default value in case of error
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
return res
def loglikelihood_rolling(self, requests):
raise NotImplementedError(
"loglikelihood_rolling not yet supported for GGUF models"
)
import unittest
from unittest.mock import patch
import hashlib
import json
import os
import pickle
from lm_eval.models.gguf import GGUFLM
from lm_eval.api.instance import Instance
base_url = "https://matthoffner-ggml-llm-api.hf.space"
def gguf_completion_mock(base_url=None, **kwargs):
# Generate a hash from the parameters
hash_kwargs = {"base_url": base_url, **kwargs}
hash = hashlib.sha256(
json.dumps(hash_kwargs, sort_keys=True).encode("utf-8")
).hexdigest()
fname = f"./tests/testdata/gguf_test_{hash}.pkl"
if os.path.exists(fname):
with open(fname, "rb") as fh:
return pickle.load(fh)
else:
print("The file does not exist, attempting to write...")
if "stop" in kwargs:
result = {
"choices": [
{
"text": f"generated text until {kwargs['stop']}",
"logprobs": {"token_logprobs": [-1.2345], "text_offset": 0},
"finish_reason": "length",
}
]
}
else:
# generated with # curl -X 'POST' 'http://localhost:8000/v1/completions' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"prompt": "string", "logprobs": 10, "temperature": 0.0, "max_tokens": 1, "echo": true}'
result = {
"id": "cmpl-4023976b-bc6a-43b0-a5a9-629f4216c7f3",
"object": "text_completion",
"created": 1700511361,
"model": "../llama-2-7b.Q8_0.gguf",
"choices": [
{
"text": "string(",
"index": 0,
"logprobs": {
"text_offset": [0, 7],
"token_logprobs": [None, -1.033263319857306],
"tokens": [" string", "("],
"top_logprobs": [
None,
{
"(": -1.033263319857306,
"[]": -2.6530743779017394,
".": -3.0377145947291324,
"\n": -3.0399156750513976,
"_": -3.510376089937872,
" =": -3.6957918347193663,
",": -3.9309459866358702,
" of": -4.2834550083949035,
'("': -4.322762841112799,
"()": -4.426229113466925,
},
],
},
"finish_reason": "length",
}
],
"usage": {
"prompt_tokens": 2,
"completion_tokens": 1,
"total_tokens": 3,
},
}
try:
os.makedirs(os.path.dirname(fname), exist_ok=True)
print("Writing file at", fname)
with open(fname, "wb") as fh:
pickle.dump(result, fh)
print("File written successfully")
except Exception as e:
print("File writing failed:", e)
return result
class GGUFLMTest(unittest.TestCase):
@patch(
"lm_eval.models.gguf.GGUFLM.gguf_completion", side_effect=gguf_completion_mock
)
def test_loglikelihood(self, gguf_completion_mock):
lm = GGUFLM(base_url)
# Test loglikelihood
requests = [
Instance(
request_type="loglikelihood",
doc=args,
arguments=args,
idx=i,
)
for i, args in enumerate([("str", "ing"), ("str", "ing")])
]
res = lm.loglikelihood(requests)
# Assert the loglikelihood response is correct
expected_res = [(logprob, True) for logprob in [0, 0]]
self.assertEqual(res, expected_res)
@patch(
"lm_eval.models.gguf.GGUFLM.gguf_completion", side_effect=gguf_completion_mock
)
def test_generate_until(self, gguf_completion_mock):
lm = GGUFLM(base_url)
# Test generate_until
requests = [
Instance(
request_type="generate_until",
doc={"input": doc},
arguments=(doc, {"until": stop}),
idx=i,
)
for i, (doc, stop) in enumerate([("input1", "stop1"), ("input2", "stop2")])
]
res = lm.generate_until(requests)
# Assert the generate_until response is correct
expected_res = ["generated text until stop1", "generated text until stop2"]
self.assertEqual(res, expected_res)
# @patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
# def test_loglikelihood_rolling(self, gguf_completion_mock):
# lm = GGUFLM(base_url)
# # Test loglikelihood_rolling
# requests = ["input1", "input2"]
# res = lm.loglikelihood_rolling(requests)
# # Assert the loglikelihood_rolling response is correct
# expected_res = [(-1.2345, True), (-1.2345, True)]
# self.assertEqual(res, expected_res)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment