Commit 4b133dca authored by Leo Gao

Merge branch 'master' of github.com:EleutherAI/lm_evaluation_harness into cfsquad

# Conflicts:
#	lm_eval/tasks/squad.py
parents 8de85534 caba51e1
@@ -59,6 +59,13 @@ The goal of this project is to build a set of tools for evaluating LMs on typica
|ethics_utilitarianism_original|✓ |✓ |✓ |acc |
|ethics_utilitarianism |✓ |✓ |✓ |acc |
|ethics_virtue |✓ |✓ |✓ |acc, em |
|math_algebra |✓ | |✓ |acc |
|math_counting_and_prob |✓ | |✓ |acc |
|math_geometry |✓ | |✓ |acc |
|math_intermediate_algebra |✓ | |✓ |acc |
|math_num_theory |✓ | |✓ |acc |
|math_prealgebra |✓ | |✓ |acc |
|math_precalc |✓ | |✓ |acc |
|arithmetic_2da | |✓ | |acc |
|arithmetic_2ds | |✓ | |acc |
|arithmetic_3da | |✓ | |acc |
......
......@@ -73,6 +73,7 @@ class Task(abc.ABC):
def __init__(self):
self.download()
self._training_docs = None
self._fewshot_docs = None
def download(self):
"""Downloads the task dataset if necessary"""
@@ -114,10 +115,11 @@ class Task(abc.ABC):
"""
return []
def fewshot_examples(self, k):
def fewshot_examples(self, k, rnd):
if self._training_docs is None:
self._training_docs = list(self.training_docs())
return random.sample(self._training_docs, k)
return rnd.sample(self._training_docs, k)
@abc.abstractmethod
def doc_to_text(self, doc):
@@ -175,15 +177,27 @@ class Task(abc.ABC):
def fewshot_description(self):
return ""
def fewshot_context(self, doc, num_fewshot, provide_description):
def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
raw_description = self.fewshot_description()
description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""
if num_fewshot == 0:
labeled_examples = ""
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs() else self.test_docs())
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# drop the doc we're currently evaluating, if it happens to be in the few-shot sample
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = "\n\n".join(
[self.doc_to_text(doc) + self.doc_to_target(doc) for doc in self.fewshot_examples(k=num_fewshot)]
[self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
) + "\n\n"
example = self.doc_to_text(doc)
......
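The `rnd` parameter now threaded through `fewshot_context` and `fewshot_examples` makes few-shot sampling reproducible. A minimal usage sketch (the task name and seed here are illustrative, not part of this commit):

import random
from lm_eval.tasks import TASK_REGISTRY

task = TASK_REGISTRY["logiqa"]()   # instantiation downloads the dataset if needed
rnd = random.Random(42)            # fixed seed -> deterministic few-shot examples
doc = next(iter(task.test_docs()))
ctx = task.fewshot_context(doc, num_fewshot=5, provide_description=True, rnd=rnd)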
@@ -23,12 +23,12 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
# get lists of each type of request
for task_name, task in task_dict_items:
#default to validation doc, fall back to test doc if validation unavailable
# TODO: the val-fallback-to-test system isn't final, we should revisit it at some point
if task.has_validation_docs():
task_doc_func = task.validation_docs
elif task.has_test_docs():
# default to test docs, fall back to validation docs if test docs are unavailable
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
if task.has_test_docs():
task_doc_func = task.test_docs
elif task.has_validation_docs():
task_doc_func = task.validation_docs
# deterministically shuffle the docs and keep only the first `limit`, because sometimes docs come in a meaningful order
task_docs = list(task_doc_func())
@@ -43,6 +43,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
doc=doc,
provide_description=provide_description,
num_fewshot=num_fewshot,
rnd=rnd
)
reqs = task.construct_requests(doc, ctx)
......
import math
from collections.abc import Iterable
from pprint import pprint
import numpy as np
import sacrebleu
......
@@ -9,19 +9,25 @@ from tqdm import tqdm
class GPT2LM(LM):
MAX_GEN_TOKS = 256
def __init__(self, device="cpu", pretrained='gpt2'):
self.device = torch.device(device)
self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
def __init__(self, device=None, pretrained='gpt2'):
if device:
self.device = torch.device(device)
else:
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
self.gpt2.eval()
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
# pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
self.tokenizer.pad_token = "<|endoftext|>"
self.max_length = self.gpt2.config.n_ctx
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
@classmethod
def create_from_arg_string(cls, arg_string):
args = utils.simple_parse_args_string(arg_string)
return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))
return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))
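For example, assuming `utils.simple_parse_args_string` parses comma-separated key=value pairs (its implementation is not shown in this diff), the CLI can pick a model and device in one string:

lm = GPT2LM.create_from_arg_string("pretrained=gpt2,device=cuda:0")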
def loglikelihood(self, requests):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
@@ -29,7 +35,13 @@ class GPT2LM(LM):
with torch.no_grad():
# TODO: vectorize properly
# TODO: automatic batch size detection for vectorization
for context, continuation in tqdm(requests):
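# collation key: tokenized length plus the decoded text minus its final token, so requests
# that differ only in their last token share a key (see the last-token optimization below)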
def _collate(x):
toks = self.tokenizer.encode(x[0] + x[1])[:-1]
return (len(toks), self.tokenizer.decode(toks))
reord = utils.Reorderer(requests, _collate)
for context, continuation in tqdm(reord.get_reordered()):
# when too long to fit in context, truncate from the left
if context == "":
@@ -39,8 +51,8 @@ class GPT2LM(LM):
context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation)
inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
cont_toks = inp[:, ctxlen:] # [batch, seq]
logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
@@ -48,22 +60,32 @@ class GPT2LM(LM):
greedy_tokens = logits.argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]
last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]
res.append((float(logits.sum()), bool(max_equal)))
res.append((float(logits[:, :-1].sum() if logits.shape[-1] > 1 else 0), last_token_slice, bool(max_equal)))
return res
# optimization: if two requests have everything the same except the last token, use
# last token distribution to save compute
lasttoks = [self.tokenizer.encode(x[1])[-1] for x in requests]
return [(l + lts[lasttok], m) for (l, lts, m), lasttok in zip(reord.get_original(res), lasttoks)]
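`utils.Reorderer` is used here but not defined in this diff; a minimal sketch of the interface these call sites assume (sort requests by a collation key, then map results back to their original order):

class Reorderer:
    # sketch only, matching the get_reordered/get_original usage above
    def __init__(self, arr, fn):
        self.arr = arr
        self.order = sorted(range(len(arr)), key=lambda i: fn(arr[i]))

    def get_reordered(self):
        # the requests, sorted by the collation key
        return [self.arr[i] for i in self.order]

    def get_original(self, res):
        # undo the permutation: res[pos] belongs at original index order[pos]
        out = [None] * len(res)
        for pos, i in enumerate(self.order):
            out[i] = res[pos]
        return out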
def greedy_until(self, requests):
# TODO: implement fully general `until` support that correctly handles stop strings
# which are multiple tokens or which span token boundaries
res = []
for context, until in tqdm(requests):
def _collate(x):
toks = self.tokenizer.encode(x[0])
return (len(toks), x[0])
reord = utils.Reorderer(requests, _collate)
for context, until in tqdm(reord.get_reordered()):
if isinstance(until, str): until = [until]
context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - 1024:]]).to(self.device)
context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)
primary_until, = self.tokenizer.encode(until[0])
@@ -81,4 +103,4 @@ class GPT2LM(LM):
res.append(s)
return res
return reord.get_original(res)
@@ -70,7 +70,16 @@ class GPT3LM(LM):
import openai
res = []
for chunk in tqdm(list(utils.chunks(requests, self.REQ_CHUNK_SIZE))):
def _collate(x):
# this doesn't efficiently handle last-token differences yet, but those are tricky because
# it's not guaranteed that the ~100 logprobs we get back actually contain all the
# continuations we care about, so we'd need some kind of backup for when they don't
toks = self.tokenizer.encode(x[0] + x[1])
return (len(toks), self.tokenizer.decode(toks))
reord = utils.Reorderer(requests, _collate)
for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
inps = []
ctxlens = []
for context, continuation in chunk:
@@ -98,13 +107,19 @@ class GPT3LM(LM):
for resp, ctxlen in zip(response.choices, ctxlens):
res.append(get_result(resp, ctxlen))
return res
return reord.get_original(res)
def greedy_until(self, requests):
if not requests: return []
import openai
res = []
def _collate(x):
toks = self.tokenizer.encode(x[0])
return (len(toks), x[0])
reord = utils.Reorderer(requests, _collate)
def sameuntil_chunks(xs, size):
ret = []
lastuntil = xs[0][1]
@@ -118,7 +133,7 @@ class GPT3LM(LM):
if ret: yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(list(sameuntil_chunks(requests, self.REQ_CHUNK_SIZE))):
for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
inps = []
for context, _ in chunk:
context_enc = self.tokenizer.encode(context)
@@ -142,5 +157,5 @@ class GPT3LM(LM):
res.append(s)
return res
return reord.get_original(res)
@@ -33,6 +33,8 @@ from . import ethics
from . import drop
from . import unscramble
from . import logiqa
from . import hendrycks_test
from . import math
########################################
# Translation tasks
@@ -126,6 +128,15 @@ TASK_REGISTRY = {
"ethics_utilitarianism": ethics.EthicsUtilitarianism,
"ethics_virtue": ethics.EthicsVirtue,
# math
"math_algebra": math.MathAlgebra,
"math_counting_and_prob": math.MathCountingAndProbability,
"math_geometry": math.MathGeometry,
"math_intermediate_algebra": math.MathIntermediateAlgebra,
"math_num_theory": math.MathNumberTheory,
"math_prealgebra": math.MathPrealgebra,
"math_precalc": math.MathPrecalculus,
# arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus,
@@ -140,6 +151,9 @@ TASK_REGISTRY = {
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks)
**hendrycks_test.create_all_tasks(),
# e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20
......
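Every entry in the registry is addressable by name; an illustrative lookup using two of the names added in this commit:

task_dict = {name: TASK_REGISTRY[name]() for name in ("math_algebra", "hendrycksTest-anatomy")}  # instantiation triggers each task's download()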
@@ -3,6 +3,7 @@ from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask
class ANLIBase(HFTask):
DATASET_PATH = "anli"
DATASET_NAME = None
......
import numpy as np
from lm_eval.base import MultipleChoiceTask
from ..metrics import mean
from . common import HFTask
......
@@ -8,6 +8,7 @@ from best_download import download_file
ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
class Arithmetic(Task):
directory = 'data/arithmetic/'
......
import datasets
import numpy as np
import lm_eval.metrics
from ..base import Task
......
import os
import json
import transformers.data.metrics.squad_metrics as squad_metrics
from lm_eval.base import Task, rf, mean
from ..utils import sh
from itertools import zip_longest
import transformers.data.metrics.squad_metrics as squad_metrics
import collections
import datasets
import numpy as np
from lm_eval.base import rf, mean
from . common import HFTask
from tqdm import tqdm
import string, re
class CoQA(Task):
......
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from lm_eval.utils import sh
from .common import yesno
import abc
import csv
import os
import random
import numpy as np
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from lm_eval.utils import sh
from .common import yesno
class Ethics(Task):
def download(self):
@@ -218,7 +218,7 @@ class EthicsUtilitarianismOriginal(Ethics):
def fewshot_description(self):
return "Rate how pleasant each of the following activities is on a scale from 1 (very unpleasant) to 10 (very pleasant).\n\n"
def fewshot_examples(self, k):
def fewshot_examples(self, k, rnd):
# Override fewshot_examples because k can be at most 5 here
assert k <= 5, "There are only 5 possible shots for this task. Refer to the V2 for more."
# These prompts implicitly (under)specify the task utility function
......
import numpy as np
from lm_eval.base import rf
from ..metrics import mean, matthews_corrcoef, f1_score
from scipy.stats import pearsonr, spearmanr
from tqdm import auto as tqdm_lib
from . common import HFTask, yesno
from ..utils import general_detokenize
@@ -20,7 +18,7 @@ class CoLA(HFTask):
return True
def has_test_docs(self):
return True
return False
def fewshot_description(self):
# TODO
@@ -67,7 +65,7 @@ class SST(HFTask):
return True
def has_test_docs(self):
return True
return False
def fewshot_description(self):
return "Indicate if the sentiment of each sentence is positive or negative."
@@ -118,7 +116,7 @@ class MNLI(HFTask):
return True
def has_test_docs(self):
return True
return False
def validation_docs(self):
if self.has_validation_docs():
@@ -186,7 +184,7 @@ class QNLI(HFTask):
return True
def has_test_docs(self):
return True
return False
def doc_to_text(self, doc):
return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
@@ -234,7 +232,7 @@ class WNLI(HFTask):
return True
def has_test_docs(self):
return True
return False
def doc_to_text(self, doc):
return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
@@ -283,7 +281,7 @@ class RTE(HFTask):
return True
def has_test_docs(self):
return True
return False
def doc_to_text(self, doc):
return "{}\nQuestion: {} True or False?\nAnswer:".format(
@@ -334,7 +332,7 @@ class MRPC(HFTask):
return True
def has_test_docs(self):
return True
return False
def fewshot_description(self):
return "Indicate if both sentences mean the same thing."
@@ -386,7 +384,7 @@ class QQP(HFTask):
return True
def has_test_docs(self):
return True
return False
def fewshot_description(self):
return "Indicate if both questions ask the same thing."
......
import csv
import random
from lm_eval.base import MultipleChoiceTask
from ..utils import sh
from pathlib import Path
SUBJECTS = ['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology',
'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics',
'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics',
'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics',
'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics',
'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence',
'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes',
'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine',
'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions']
def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
e.g. {hendrycksTest-abstract_algebra: Task, hendrycksTest-anatomy: Task}
"""
return {
f"hendrycksTest-{sub}": create_task(sub) for sub in SUBJECTS
}
def create_task(subject):
class HendrycksTest(GeneralHendrycksTest):
def __init__(self):
super().__init__(subject)
return HendrycksTest
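Illustrative use of the factory (subject names come from `SUBJECTS` above):

tasks = create_all_tasks()                   # {"hendrycksTest-abstract_algebra": <class>, ...}
anatomy = tasks["hendrycksTest-anatomy"]()   # instantiate one generated task class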
class GeneralHendrycksTest(MultipleChoiceTask):
DATASET_PATH = Path("data/hendrycksTest/")
def __init__(self, subject):
self.subject = subject
super().__init__()
def download(self):
if not self.DATASET_PATH.exists():
sh("""
mkdir -p data
wget -c https://people.eecs.berkeley.edu/~hendrycks/data.tar -P data/
tar -xf data/data.tar -C data/
rm data/data.tar
mv data/data data/hendrycksTest
""")
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def _convert_standard(self, doc):
def format_example(doc, choices):
"""
Question: <prompt>
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
Answer:
"""
prompt = "Question: " + doc[0] + "\n"
prompt += "".join([f"{choices[j]}. {doc[j+1]}\n" for j in range(4)])
prompt += "Answer:"
return prompt
choices = ['A', 'B', 'C', 'D']
return {
"query": format_example(doc, choices),
"choices": doc[1:5],
"gold": choices.index(doc[5])
}
def _load_docs(self, filename):
reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',')
return (self._convert_standard(doc) for doc in reader)
def training_docs(self):
docs = []
for train_dir in ["auxiliary_train", "dev"]:
for f in (self.DATASET_PATH / train_dir).iterdir():
docs.extend(self._load_docs(f))
return docs
def validation_docs(self):
filename = self.DATASET_PATH / "val" / f"{self.subject}_val.csv"
return self._load_docs(filename)
def test_docs(self):
filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
return self._load_docs(filename)
def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't
filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
if self._fewshot_docs is None:
self._fewshot_docs = list(self._load_docs(filename))
return rnd.sample(list(self._fewshot_docs), k)
def fewshot_description(self):
subject = self.subject.replace("_", " ")
return f"The following are multiple choice questions (with answers) about {subject}."
def doc_to_text(self, doc):
return doc["query"]
import json
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
import json
import math
from best_download import download_file
......
@@ -30,10 +30,27 @@ class LogiQA(MultipleChoiceTask):
return True
def _convert_standard(self, doc):
def format_example(doc, choices):
"""
Passage: <passage>
Question: <question>
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
Answer:
"""
prompt = "Passage: " + doc["passage"] + "\n"
prompt += "Question: " + doc["question"] + "\n"
for choice, option in zip(choices, doc["options"]):
prompt += f"{choice.upper()}. {option}\n"
prompt += "Answer:"
return prompt
choices = ['a', 'b', 'c', 'd']
return {
"query": "Passage: " + doc["passage"] + "\nQuestion: " + doc["question"] + "\nAnswer:",
"query": format_example(doc, choices),
"choices": doc["options"],
"gold": ["a", "b", "c", "d"].index(doc["answerKey"])
"gold": choices.index(doc["answerKey"])
}
def _load_docs(self, filename):
......
import abc
import json
from lm_eval.utils import sh
from lm_eval.metrics import mean
from lm_eval.base import Task, rf
from pathlib import Path
class Math(Task):
"""
This dataset is based on the following paper:
https://arxiv.org/abs/2103.03874
"""
DATASET_PATH = Path('data/MATH')
def download(self):
if not self.DATASET_PATH.exists():
sh(f"""
mkdir -p {self.DATASET_PATH}
wget https://people.eecs.berkeley.edu/~hendrycks/MATH.tar.gz -P data/
tar -xvf {self.DATASET_PATH}.tar.gz -C data/
rm {self.DATASET_PATH}.tar.gz
""")
@abc.abstractmethod
def get_file_info(self):
"""returns directory name"""
pass
def has_training_docs(self):
return True
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def _load_docs(self, path):
for file in path.iterdir():
with open(file) as f:
doc = json.load(f)
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
yield doc
def training_docs(self):
return self._load_docs(self.DATASET_PATH / "train" / self.get_file_info())
def validation_docs(self):
raise NotImplementedError
def test_docs(self):
return self._load_docs(self.DATASET_PATH / "test" / self.get_file_info())
def fewshot_description(self):
return "Given a mathematics problem, determine the answer. Simplify your answer as much as possible."
def doc_to_text(self, doc):
return "Problem: " + doc["problem"] + "\nAnswer:"
def doc_to_target(self, doc):
return " " + doc["answer"]
def construct_requests(self, doc, ctx):
return rf.greedy_until(ctx, ["\n"])
def process_results(self, doc, results):
retval = 0
indices = [pos for pos, char in enumerate(results[0]) if char == "$"]
if len(indices) <= 1:
answer = results[0]
else:
answer = results[0][indices[0]+1:indices[-1]]
if self.is_equiv(answer, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))):
retval = 1
return {
"acc": retval
}
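A worked example of the `$`-extraction above (the output string is illustrative): with two or more dollar signs in the model output, the answer is the span between the first and the last one; otherwise the raw output is used.

out = "$\\frac{1}{2}$"
indices = [pos for pos, char in enumerate(out) if char == "$"]
answer = out[indices[0]+1:indices[-1]] if len(indices) > 1 else out
# answer == "\\frac{1}{2}"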
def aggregation(self):
return {
'acc': mean
}
def higher_is_better(self):
return {
'acc': True
}
def is_equiv(self, str1, str2, verbose=False):
if str1 is None and str2 is None:
print("WARNING: Both None")
return True
if str1 is None or str2 is None:
return False
try:
ss1 = self.strip_string(str1)
ss2 = self.strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except Exception:
return str1 == str2
def remove_boxed(self, s):
left = "\\boxed{"
try:
assert s[:len(left)] == left
assert s[-1] == "}"
return s[len(left):-1]
except AssertionError:
return None
def last_boxed_only_string(self, string):
idx = string.rfind("\\boxed")
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
return retval
def fix_fracs(self, string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def fix_a_slash_b(self, string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except AssertionError:
return string
def remove_right_units(self, string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def fix_sqrt(self, string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
class NotEqual:
def __eq__(self, other):
return False
def strip_string(self, string):
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = self.remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = self.fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1); a/b --> \frac{a}{b} is handled by fix_a_slash_b below
string = self.fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = self.fix_a_slash_b(string)
return string
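A few equivalences implied by the normalization above, shown as a sketch (`__new__` skips `__init__` only to avoid triggering the dataset download; `MathAlgebra` is defined below):

m = MathAlgebra.__new__(MathAlgebra)      # any concrete Math subclass works here
assert m.is_equiv("0.5", "\\frac12")      # both sides normalize to \frac{1}{2}
assert m.is_equiv("1/2", "\\frac{1}{2}")  # fix_a_slash_b rewrites 1/2
assert m.is_equiv("10\\%", "10")          # percent signs are stripped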
class MathAlgebra(Math):
def get_file_info(self):
return 'algebra'
class MathCountingAndProbability(Math):
def get_file_info(self):
return 'counting_and_probability'
class MathGeometry(Math):
def get_file_info(self):
return 'geometry'
class MathIntermediateAlgebra(Math):
def get_file_info(self):
return 'intermediate_algebra'
class MathNumberTheory(Math):
def get_file_info(self):
return 'number_theory'
class MathPrealgebra(Math):
def get_file_info(self):
return 'prealgebra'
class MathPrecalculus(Math):
def get_file_info(self):
return 'precalculus'
from . common import HFTask
from lm_eval.base import mean, rf, MultipleChoiceTask
import re
from lm_eval.base import MultipleChoiceTask
from . common import HFTask
class MathQA(HFTask, MultipleChoiceTask):
DATASET_PATH = "math_qa"
......
import random
from . common import HFTask
from itertools import islice
import random
class NaturalQs(HFTask):
# TODO: naturalqs has a *really* large train set that huggingface just
@@ -28,12 +29,12 @@ class NaturalQs(HFTask):
# Data is too large to fit in memory.
return self.data["train"]
def fewshot_examples(self, k):
def fewshot_examples(self, k, rnd):
# Data is too large to fit in memory. We just sample from the first bit.
if self._training_docs is None:
self._training_docs = list(islice(self.training_docs(), 0, 100000))
return random.sample(self._training_docs, k)
return rnd.sample(self._training_docs, k)
def doc_to_text(self, doc):
return 'Q: ' + doc['question']['text'] + '\n\n' + 'A: '
......