Unverified Commit 8941c067 authored by Michael Goin, committed by GitHub

Merge branch 'EleutherAI:master' into deepsparselm

parents ff24ddbd 008fc2a2
-* @haileyschoelkopf @lintangsutawika
+* @haileyschoelkopf @lintangsutawika @StellaAthena
-FROM nvidia/cuda:11.2.0-cudnn8-runtime-ubuntu20.04
+FROM nvidia/cuda:11.2.2-cudnn8-runtime-ubuntu20.04
### Install python 3.10 and set it as default python interpreter
......
@@ -141,6 +141,15 @@ python main.py \
--tasks hellaswag
```
GGUF or GGML quantized models can be loaded by using the `llama-cpp-python` server:
```bash
python main.py \
--model gguf \
--model_args base_url=http://localhost:8000 \
--tasks hellaswag
```
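If no compatible endpoint is running yet, one way to start one locally is with the `llama-cpp-python` server (a minimal sketch, assuming the package is installed with its `server` extra; the model path below is a placeholder):
```bash
pip install "llama-cpp-python[server]"
# Serves an OpenAI-compatible completions API on http://localhost:8000 by default.
# Replace the path with your own GGUF model file.
python -m llama_cpp.server --model ./models/your-model.Q4_K_M.gguf
```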
We support wildcards in task names, for example you can run all of the machine-translated lambada tasks via `--task lambada_openai_mt_*`.
We currently only support one prompt per task, which we strive to make the "standard" as defined by the benchmark's authors. If you would like to study how varying prompts causes changes in the evaluation score, check out the [BigScience fork](https://github.com/bigscience-workshop/lm-evaluation-harness) of this repo. We are currently working on upstreaming this capability to `main`.
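For instance, reusing the `gguf` endpoint from the example above, every machine-translated LAMBADA variant can be evaluated in a single run (the wildcard is quoted so the shell does not expand it):
```bash
python main.py \
    --model gguf \
    --model_args base_url=http://localhost:8000 \
    --tasks "lambada_openai_mt_*"
```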
......
@@ -7,7 +7,6 @@ import os
import json
import hashlib
import datasets
-from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F
@@ -891,6 +890,7 @@ class CachingLM:
:param cache_db: str
Path to cache db
"""
+from sqlitedict import SqliteDict
self.lm = lm
self.cache_db = cache_db
if os.path.dirname(cache_db):
......
@@ -43,7 +43,7 @@ level (for indicating the level of difficulty).
_HOMEPAGE = "https://github.com/chaochun/nlu-asdiv-dataset"
-# License available at https://github.com/chaochun/nlu-asdiv-dataset/blob/master/README.md
+# License declared at https://github.com/chaochun/nlu-asdiv-dataset/blob/master/README.md
_LICENSE = "CC BY-NC 4.0"
_URLS = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccfa5053194b25732534696b50.zip"
......
@@ -43,7 +43,7 @@ and perform discrete operations over them (such as addition, counting, or sortin
_HOMEPAGE = "https://allenai.org/data/drop"
-# License available at https://allenai.org/data/drop
+# License declared at https://allenai.org/data/drop
_LICENSE = "CC BY"
_URLS = {
......
@@ -51,8 +51,10 @@ The dataset contains questions about the following topics: medicine, nursing, ps
_HOMEPAGE = "https://aghie.github.io/head-qa/"
-# License available at https://github.com/aghie/head-qa/blob/master/LICENSE
-_LICENSE = "MIT License"
+# The Spanish data comes from the "Ministerio de Sanidad, Consumo y Bienestar Social", as indicated here : https://github.com/aghie/head-qa
+# This Spanish data seems to follow the intellectual property rights stated here : https://www.sanidad.gob.es/avisoLegal/home.htm
+# The English data was translated by the authors of head-qa (https://arxiv.org/pdf/1906.04701.pdf).
+_LICENSE = "Custom license"
_URL = "https://drive.google.com/uc?export=download&confirm=t&id=1a_95N5zQQoUCq8IBNVZgziHbeM-QxG2t"
......
@@ -41,8 +41,10 @@ learning agents.
_HOMEPAGE = "https://github.com/hendrycks/ethics"
-# License available at https://github.com/hendrycks/ethics/blob/master/LICENSE
-_LICENSE = "MIT License"
+# The authors declared that the dataset is not distributed under a copyright or intellectual property (https://arxiv.org/pdf/2008.02275.pdf)
+# On Hugging Face, the dataset is distributed under the MIT license (https://huggingface.co/datasets/hendrycks/ethics)
+# The common sense portion is from Reddit and might incur some licensing complications.
+_LICENSE = "Ambiguous"
_URLS = "https://people.eecs.berkeley.edu/~hendrycks/ethics.tar"
......
@@ -38,7 +38,7 @@ models to generate answer derivations and explanations.
_HOMEPAGE = "https://github.com/hendrycks/math"
-# License available at https://github.com/hendrycks/math/blob/main/LICENSE
+# License declared at https://arxiv.org/pdf/2103.03874.pdf
_LICENSE = "MIT License"
_URLS = "https://people.eecs.berkeley.edu/~hendrycks/MATH.tar"
......
@@ -38,8 +38,8 @@ math, computer science, and philosophy papers.
_HOMEPAGE = "https://pile.eleuther.ai/"
-# License available at https://github.com/EleutherAI/the-pile/blob/master/LICENSE
-_LICENSE = "MIT License"
+# More details at https://arxiv.org/pdf/2201.07311.pdf
+_LICENSE = "Multiple licenses"
_URLS = {
"validation": "https://the-eye.eu/public/AI/pile/val.jsonl.zst",
......
@@ -39,7 +39,7 @@ a teacher who answers the questions by providing short excerpts (spans) from the
_HOMEPAGE = "https://quac.ai/"
-# License available at https://quac.ai/
+# License declared at https://quac.ai/
_LICENSE = "CC BY-SA 4.0"
_URLS = {
......
@@ -5,6 +5,7 @@ from . import huggingface
from . import textsynth
from . import deepsparse
from . import dummy
+from . import gguf
MODEL_REGISTRY = {
"hf": gpt2.HFLM,
@@ -17,6 +18,7 @@ MODEL_REGISTRY = {
"textsynth": textsynth.TextSynthLM,
"deepsparse": deepsparse.DeepSparseLM,
"dummy": dummy.DummyLM,
+"gguf": gguf.GGUFLM
}
......
import requests
import logging
import time
from tqdm import tqdm
from requests.exceptions import RequestException
import transformers
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM

logger = logging.getLogger(__name__)


def get_result(logprobs, context_length):
    # Skip the tokens that belong to the context (using the character offsets
    # returned by the server), sum the log-probabilities of the continuation
    # tokens, and check whether every continuation token was also the top-1
    # (greedy) choice.
    is_greedy = True
    offsets = logprobs['text_offset']
    tokens = logprobs['tokens']
    tokens_logprobs = logprobs['token_logprobs']

    idx = 0
    while offsets[idx] < context_length:
        idx += 1
    continuation_logprobs = sum(tokens_logprobs[idx:-1])
    for i in range(idx, len(tokens)):
        token = tokens[i]
        top_tokens = logprobs["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


class GGUFLM(BaseLM):
    def __init__(self, base_url, max_length=2048):
        super().__init__()
        self.base_url = base_url
        self.logprobs = 10
        self.temperature = 0.0
        self._max_length = max_length

    def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
        for _ in range(retries):
            try:
                prompt = context
                request = {'prompt': prompt, 'logprobs': self.logprobs,
                           'temperature': self.temperature}
                if continuation:
                    # Score the continuation by echoing the full prompt back
                    # with per-token logprobs and generating at most one new token.
                    prompt += continuation
                    request.update({'prompt': prompt, 'max_tokens': 1, 'echo': True})
                if stop is not None:
                    request['stop'] = stop
                response = requests.post(f"{self.base_url}/v1/completions", json=request)
                response.raise_for_status()
                return response.json()
            except RequestException as e:
                logger.error(f"RequestException: {e}")
                time.sleep(delay)  # wait before retrying
        else:
            raise Exception(f"Failed to get a valid response after {retries} retries.")

    def loglikelihood(self, requests):
        if not requests:
            return []
        res = []
        for context, continuation in tqdm(requests):
            response = self.gguf_completion(context=context, continuation=continuation)
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                logprobs = choice.get("logprobs")
                if logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]:
                    logprob, is_greedy = get_result(logprobs, len(context))
                    res.append((logprob, is_greedy))
                else:
                    logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.")
            else:
                logger.error(f"Invalid response for loglikelihood. Response: {response}")
                assert False
        return res

    def greedy_until(self, requests):
        if not requests:
            return []

        res = []
        for request in tqdm(requests):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = self.gguf_completion(context=inp, stop=until)
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                if "text" in choice:
                    generated_text = choice["text"].strip()
                    res.append(generated_text)
                else:
                    logger.error(f"Invalid response for greedy_until. Response: {response}")
                    res.append(None)  # Add default value in case of error
            else:
                logger.error(f"Invalid response for greedy_until. Response: {response}")
                res.append(None)  # Add default value in case of error
        return res

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError("loglikelihood_rolling not yet supported for GGUF models")

    def _model_call(self, inps):
        # Placeholder implementation
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Placeholder implementation
        raise NotImplementedError()

    def tok_encode(self, string: str):
        raise NotImplementedError()

    def tok_decode(self, tokens):
        raise NotImplementedError()

    @property
    def batch_size(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def device(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def eot_token_id(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def max_length(self):
        # Stored as _max_length so the attribute does not shadow this property.
        return self._max_length

    @property
    def max_gen_toks(self):
        # Placeholder implementation
        raise NotImplementedError()
@@ -19,7 +19,6 @@ _DeviceMapping = NewType("DeviceMapping", Mapping[str, Union[int, str, torch.dev
def _get_accelerate_args(
-low_cpu_mem_usage: Optional[bool] = True,
device_map_option: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
@@ -39,7 +38,6 @@ def _get_accelerate_args(
args = {}
if max_memory:
args["max_memory"] = max_memory
-args["low_cpu_mem_usage"] = low_cpu_mem_usage
args["device_map"] = device_map_option
args["offload_folder"] = offload_folder
return args
@@ -222,7 +220,6 @@ class HuggingFaceAutoLM(BaseLM):
model_kwargs = {}
if use_accelerate:
model_kwargs = _get_accelerate_args(
-low_cpu_mem_usage,
device_map_option,
max_memory_per_gpu,
max_cpu_memory,
@@ -242,6 +239,7 @@ class HuggingFaceAutoLM(BaseLM):
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
+low_cpu_mem_usage=low_cpu_mem_usage,
**model_kwargs,
)
# note: peft_path can be different than pretrained model path
......
@@ -349,7 +349,7 @@ TASK_REGISTRY = {
**mgsm.construct_tasks(),
**scrolls.construct_tasks(),
**ceval.create_all_tasks(),
-**cmmlu.create_all_tasks(),
+**cmmlu.create_all_tasks()
}
......
@@ -2,12 +2,19 @@
CMMLU: Measuring massive multitask language understanding in Chinese
https://arxiv.org/abs/2306.09212
-CMMLU is a comprehensive evaluation benchmark specifically designed to evaluate the knowledge and reasoning abilities of LLMs within the context of Chinese language and culture.
-CMMLU covers a wide range of subjects, comprising 67 topics that span from elementary to advanced professional levels.
+CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge
+and reasoning abilities of LLMs within the Chinese language and cultural context. CMMLU covers a wide range of
+subjects, comprising 67 topics that span from elementary to advanced professional levels. It includes subjects that
+require computational expertise, such as physics and mathematics, as well as disciplines within humanities and
+social sciences. Many of these tasks are not easily translatable from other languages due to their specific
+contextual nuances and wording. Furthermore, numerous tasks within CMMLU have answers that are specific to
+China and may not be universally applicable or considered correct in other regions or languages.
Homepage: https://github.com/haonan-li/CMMLU
+Huggingface homepage: https://huggingface.co/datasets/haonan-li/cmmlu
"""
-from lm_eval.base import MultipleChoiceTask
+import os
+from lm_eval.base import MultipleChoiceTask, rf
_CITATION = """
@misc{li2023cmmlu,
@@ -21,7 +28,77 @@ _CITATION = """
"""
-SUBJECTS = {
+SUBJECTS = [
"agronomy",
"anatomy",
"ancient_chinese",
"arts",
"astronomy",
"business_ethics",
"chinese_civil_service_exam",
"chinese_driving_rule",
"chinese_food_culture",
"chinese_foreign_policy",
"chinese_history",
"chinese_literature",
"chinese_teacher_qualification",
"clinical_knowledge",
"college_actuarial_science",
"college_education",
"college_engineering_hydrology",
"college_law",
"college_mathematics",
"college_medical_statistics",
"college_medicine",
"computer_science",
"computer_security",
"conceptual_physics",
"construction_project_management",
"economics",
"education",
"electrical_engineering",
"elementary_chinese",
"elementary_commonsense",
"elementary_information_and_technology",
"elementary_mathematics",
"ethnology",
"food_science",
"genetics",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_geography",
"high_school_mathematics",
"high_school_physics",
"high_school_politics",
"human_sexuality",
"international_law",
"journalism",
"jurisprudence",
"legal_and_moral_basis",
"logical",
"machine_learning",
"management",
"marketing",
"marxist_theory",
"modern_chinese",
"nutrition",
"philosophy",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_study",
"sociology",
"sports_science",
"traditional_chinese_medicine",
"virology",
"world_history",
"world_religions"
]
SUBJECT_MAPPING = {
"agronomy": "农学", "agronomy": "农学",
"anatomy": "解剖学", "anatomy": "解剖学",
"ancient_chinese": "古汉语", "ancient_chinese": "古汉语",
...@@ -91,26 +168,103 @@ SUBJECTS = { ...@@ -91,26 +168,103 @@ SUBJECTS = {
"world_religions": "世界宗教", "world_religions": "世界宗教",
} }
SUBJECT_CATEGORIES = {
"agronomy": ['other'],
"anatomy": ['biology'],
"ancient_chinese": ['linguistics','china specific'],
"arts": ['arts'],
"astronomy": ['physics'],
"business_ethics": ['business'],
"chinese_civil_service_exam": ['politics','china specific'],
"chinese_driving_rule": ['other','china specific'],
"chinese_food_culture": ['culture','china specific'],
"chinese_foreign_policy": ['politics','china specific'],
"chinese_history":['history','china specific'],
"chinese_literature": ['literature','china specific'],
"chinese_teacher_qualification": ['education','china specific'],
"college_actuarial_science":['math'],
"college_education":['education'],
"college_engineering_hydrology": ['engineering'],
"college_law": ['law'],
"college_mathematics": ['math'],
"college_medical_statistics":['statistics'],
"clinical_knowledge": ['other'],
"college_medicine": ['other'],
"computer_science": ['computer science'],
"computer_security": ['other'],
"conceptual_physics": ['physics'],
"construction_project_management": ['other','china specific'],
"economics": ['economics'],
"education": ['education'],
"elementary_chinese":['linguistics','china specific'],
"elementary_commonsense":['other','china specific'],
"elementary_information_and_technology": ['other'],
"electrical_engineering": ['engineering'],
"elementary_mathematics": ['math'],
"ethnology": ['culture','china specific'],
"food_science": ['other'],
"genetics": ['biology'],
"global_facts": ['global'],
"high_school_biology": ['biology'],
"high_school_chemistry": ['chemistry'],
"high_school_geography": ['geography'],
"high_school_mathematics": ['math'],
"high_school_physics": ['physics'],
"high_school_politics": ['politics','china specific'],
"human_sexuality": ['other'],
"international_law": ['law'],
"journalism": ['sociology'],
"jurisprudence": ['law'],
"legal_and_moral_basis": ['other'],
"logical": ['philosophy'],
"machine_learning": ['computer science'],
"management": ['business'],
"marketing": ['business'],
"marxist_theory": ['philosophy'],
"modern_chinese": ['linguistics','china specific'],
"nutrition": ['other'],
"philosophy": ['philosophy'],
"professional_accounting": ['business'],
"professional_law": ['law'],
"professional_medicine": ['other'],
"professional_psychology": ['psychology'],
"public_relations": ['politics'],
"security_study": ['politics'],
"sociology": ['culture'],
"sports_science": ['other'],
"traditional_chinese_medicine": ['other','china specific'],
"virology": ['biology'],
"world_history":['history'],
"world_religions": ['global'],
}
CATEGORIES = {
"STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering", "statistics"],
"Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
"Social Science": ['linguistics',"business", "politics", "culture", "economics", "geography", "psychology", "education", "sociology"],
"Other":["other"],
"China specific": ["china specific"],
}
def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
-e.g. {cmmlu-world_history: Task, cmmlu-virology: Task}
+e.g. {cmmlu-physician: Task, cmmlu-tax_accountant: Task}
"""
-return {f"cmmlu-{sub}": create_task(sub) for sub in SUBJECTS.keys()}
+return {f"cmmlu-{sub}": create_task(sub) for sub in SUBJECTS}
def create_task(subject):
-class Cmmlu(CmmluSubject):
+class CmmluTest(GeneralCmmluTest):
def __init__(self):
super().__init__(subject)
-return Cmmlu
+return CmmluTest
-class CmmluSubject(MultipleChoiceTask):
+class GeneralCmmluTest(MultipleChoiceTask):
VERSION = 1
-DATASET_PATH = "haonan-li/cmmlu"
+DATASET_PATH = os.path.join("haonan-li/cmmlu")
DATASET_NAME = None
def __init__(self, subject):
@@ -121,62 +275,63 @@ class CmmluSubject(MultipleChoiceTask):
return False
def has_validation_docs(self):
-return True
+return False
def has_test_docs(self):
return True
-def validation_docs(self):
-if self.has_validation_docs():
-return map(self._process_doc, self.dataset["dev"])
def test_docs(self):
-if self.has_test_docs():
return map(self._process_doc, self.dataset["test"])
-def _format_subject(self, subject):
-words = subject.split("_")
-return " ".join(words)
def fewshot_context(self, doc, num_fewshot, **kwargs):
subject = self.DATASET_NAME
-description = f"以下是关于{SUBJECTS[subject]}的单项选择题,请直接给出正确答案的选项。"
+description = f"以下是关于{SUBJECT_MAPPING[subject]}的单项选择题,请直接给出正确答案的选项。"
kwargs["description"] = description
return super().fewshot_context(doc=doc, num_fewshot=num_fewshot, **kwargs)
def _process_doc(self, doc):
def format_example(doc, keys):
"""
-<prompt>
+题目:<prompt>
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
-答案:
+答案
"""
question = doc["Question"].strip()
-choices = "".join([f"{key}. {doc[key]}\n" for key in keys])
-prompt = f"{question}\n{choices}答案:"
+choices = "".join(
+[f"{key}. {doc[key]}\n" for key in keys]
+)
+prompt = f"题目:{question}\n{choices}答案是:"
return prompt
keys = ["A", "B", "C", "D"]
return {
"query": format_example(doc, keys),
"choices": keys,
-"gold": ord(doc["Answer"]) - ord("A"),
+"gold": keys.index(doc["Answer"]),
}
def fewshot_examples(self, k, rnd):
if self._fewshot_docs is None:
self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
+# use the unchanged order of the dev set without sampling,
return self._fewshot_docs[:k]
+def construct_requests(self, doc, ctx):
+lls = [
+rf.loglikelihood(ctx, "{}".format(choice))[0] for choice in doc["choices"]
+]
+return lls
def doc_to_text(self, doc):
return doc["query"]
+def doc_to_target(self, doc):
+return doc["choices"][doc["gold"]]
def should_decontaminate(self):
return True
......
@@ -33,34 +33,43 @@ _CITATION = """
class Pubmed_QA(Task):
VERSION = 0
-DATASET_PATH = "pubmed_qa"
+DATASET_PATH = "bigbio/pubmed_qa"
-DATASET_NAME = "pqa_labeled"
+DATASET_NAME = "pubmed_qa_labeled_fold0_source"
def has_training_docs(self):
-return False
+return True
def has_validation_docs(self):
-return False
+return True
def has_test_docs(self):
return True
+def training_docs(self):
+if self.has_training_docs():
+if self._training_docs is None:
+self._training_docs = self.dataset["train"]
+return self._training_docs
+def validation_docs(self):
+if self.has_validation_docs():
+return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
-# HF is labelled as train but its really just for testing
-return self.dataset["train"]
+return self.dataset["test"]
def doc_to_text(self, doc):
-ctxs = "\n".join(doc["context"]["contexts"])
+ctxs = "\n".join(doc["CONTEXTS"])
return "Abstract: {}\nQuestion: {}\nAnswer:".format(
-ctxs, doc["question"], doc["final_decision"]
+ctxs, doc["QUESTION"], doc["final_decision"]
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
-return doc["question"] + " " + "\n".join(doc["context"]["contexts"])
+return doc["question"] + " " + "\n".join(doc["CONTEXTS"])
def doc_to_target(self, doc):
return " {}".format(doc["final_decision"])
......
@@ -12,7 +12,7 @@ setuptools.setup(
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/EleutherAI/lm-evaluation-harness",
-packages=setuptools.find_packages(),
+packages=setuptools.find_packages(exclude=["scripts.*", "scripts"]),
package_data={"lm_eval": ["**/*.json"]},
include_package_data=True,
classifiers=[
......
import unittest
from unittest.mock import patch
import hashlib
import json
import os
import pickle

from lm_eval.models.gguf import GGUFLM

base_url = "https://matthoffner-ggml-llm-api.hf.space"


def gguf_completion_mock(base_url, **kwargs):
    # Generate a hash from the parameters
    hash_kwargs = {'base_url': base_url, **kwargs}
    hash = hashlib.sha256(json.dumps(hash_kwargs, sort_keys=True).encode('utf-8')).hexdigest()

    fname = f"./tests/testdata/ggml_test_{hash}.pkl"

    if os.path.exists(fname):
        with open(fname, "rb") as fh:
            return pickle.load(fh)
    else:
        print("The file does not exist, attempting to write...")
        if 'stop' in kwargs:
            result = {"choices": [{"text": f"generated text until {kwargs['stop']}", "logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}
        else:
            result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}

        try:
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            print('Writing file at', fname)
            with open(fname, "wb") as fh:
                pickle.dump(result, fh)
            print('File written successfully')
        except Exception as e:
            print('File writing failed:', e)
        return result


class GGUFLMTest(unittest.TestCase):
    @patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
    def test_loglikelihood(self, gguf_completion_mock):
        lm = GGUFLM(base_url)

        # Test loglikelihood
        requests = [("context1", "continuation1"), ("context2", "continuation2")]
        res = lm.loglikelihood(requests)

        # Assert the loglikelihood response is correct
        expected_res = [(logprob, True) for logprob in [-1.2345, -1.2345]]
        self.assertEqual(res, expected_res)

    @patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
    def test_greedy_until(self, gguf_completion_mock):
        lm = GGUFLM(base_url)

        # Test greedy_until
        requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
        res = lm.greedy_until(requests)

        # Assert the greedy_until response is correct
        expected_res = ["generated text until stop1", "generated text until stop2"]
        self.assertEqual(res, expected_res)

    @patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
    def test_loglikelihood_rolling(self, gguf_completion_mock):
        lm = GGUFLM(base_url)

        # Test loglikelihood_rolling
        requests = ["input1", "input2"]
        res = lm.loglikelihood_rolling(requests)

        # Assert the loglikelihood_rolling response is correct
        expected_res = [(-1.2345, True), (-1.2345, True)]
        self.assertEqual(res, expected_res)


if __name__ == "__main__":
    unittest.main()