Commit 60c9c170 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

Merge branch 'main' into inverse-scaling-tasks

parents 4b2d565b b4cd85d4
include: hendrycks_math_algebra.yaml
dataset_name: counting_and_probability
task: hendrycks_math_counting_and_prob
include: hendrycks_math_algebra.yaml
dataset_name: geometry
task: hendrycks_math_geometry
include: hendrycks_math_algebra.yaml
dataset_name: intermediate_algebra
task: hendrycks_math_intermediate_algebra
include: hendrycks_math_algebra.yaml
dataset_name: number_theory
task: hendrycks_math_num_theory
include: hendrycks_math_algebra.yaml
dataset_name: prealgebra
task: hendrycks_math_prealgebra
include: hendrycks_math_algebra.yaml
dataset_name: precalculus
task: hendrycks_math_precalc
from typing import Dict, List
import datasets
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
def _process_doc(doc: dict) -> dict:
out_doc = {
"problem": doc["problem"],
"solution": doc["solution"],
"answer": remove_boxed(last_boxed_only_string(doc["solution"])),
}
return out_doc
return dataset.map(_process_doc)
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
retval = 0
indices = [pos for pos, char in enumerate(results[0]) if char == "$"]
if len(indices) <= 1:
answer = results[0]
else:
answer = results[0][indices[0] + 1 : indices[-1]]
if is_equiv(answer, remove_boxed(last_boxed_only_string(doc["solution"]))):
retval = 1
results = {
"exact_match": retval,
}
return results
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
if str1 is None and str2 is None:
print("WARNING: Both None")
return True
if str1 is None or str2 is None:
return False
try:
ss1 = strip_string(str1)
ss2 = strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except Exception:
return str1 == str2
def remove_boxed(s):
if "\\boxed " in s:
left = "\\boxed "
assert s[: len(left)] == left
return s[len(left) :]
left = "\\boxed{"
assert s[: len(left)] == left
assert s[-1] == "}"
return s[len(left) : -1]
def last_boxed_only_string(string):
idx = string.rfind("\\boxed")
if "\\boxed " in string:
return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx : right_brace_idx + 1]
return retval
def fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except AssertionError:
return string
def remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def strip_string(string):
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "") # noqa: W605
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = fix_a_slash_b(string)
return string
...@@ -12,7 +12,6 @@ generation_kwargs: ...@@ -12,7 +12,6 @@ generation_kwargs:
temperature: 0.0 temperature: 0.0
max_gen_toks: 1280 max_gen_toks: 1280
process_results: !function utils.process_results process_results: !function utils.process_results
num_fewshot: 0
metric_list: metric_list:
- metric: prompt_level_strict_acc - metric: prompt_level_strict_acc
aggregation: mean aggregation: mean
......
...@@ -28,16 +28,11 @@ Eprint = {arXiv:2206.14858}, ...@@ -28,16 +28,11 @@ Eprint = {arXiv:2206.14858},
} }
``` ```
### Groups, Benchmarks and Tasks ### Groups and Tasks
#### Benchmarks
- `minerva_math`
#### Groups #### Groups
- `math_word_problems` - `minerva_math`
- `generate_until`
#### Tasks #### Tasks
......
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
Take in a YAML, and output all "other" splits with this YAML Take in a YAML, and output all "other" splits with this YAML
""" """
import argparse import argparse
import logging
import os import os
import yaml import yaml
from tqdm import tqdm from tqdm import tqdm
from lm_eval.logger import eval_logger
eval_logger = logging.getLogger("lm-eval")
SUBJECTS = { SUBJECTS = {
......
dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split
output_type: multiple_choice
test_split: test
fewshot_split: dev
fewshot_config:
sampler: first_n
doc_to_text: "Question: {{question.strip()}}\nAnswer:"
doc_to_choice: "{{choices}}"
doc_to_target: "{{answer}}"
metadata:
version: 0.0
group: mmlu_continuation
task:
- mmlu_continuation_stem
- mmlu_continuation_other
- mmlu_continuation_social_sciences
- mmlu_continuation_humanities
"dataset_name": "abstract_algebra"
"description": "The following are questions (with answers) about abstract\
\ algebra.\n\n"
"group": "mmlu_continuation_stem"
"include": "_continuation_template_yaml"
"task": "mmlu_continuation_abstract_algebra"
"dataset_name": "anatomy"
"description": "The following are questions (with answers) about anatomy.\n\
\n"
"group": "mmlu_continuation_stem"
"include": "_continuation_template_yaml"
"task": "mmlu_continuation_anatomy"
"dataset_name": "astronomy"
"description": "The following are questions (with answers) about astronomy.\n\
\n"
"group": "mmlu_continuation_stem"
"include": "_continuation_template_yaml"
"task": "mmlu_continuation_astronomy"
"dataset_name": "business_ethics"
"description": "The following are questions (with answers) about business\
\ ethics.\n\n"
"group": "mmlu_continuation_other"
"include": "_continuation_template_yaml"
"task": "mmlu_continuation_business_ethics"
"dataset_name": "clinical_knowledge"
"description": "The following are questions (with answers) about clinical\
\ knowledge.\n\n"
"group": "mmlu_continuation_other"
"include": "_continuation_template_yaml"
"task": "mmlu_continuation_clinical_knowledge"
"dataset_name": "college_biology"
"description": "The following are questions (with answers) about college\
\ biology.\n\n"
"group": "mmlu_continuation_stem"
"include": "_continuation_template_yaml"
"task": "mmlu_continuation_college_biology"
"dataset_name": "college_chemistry"
"description": "The following are questions (with answers) about college\
\ chemistry.\n\n"
"group": "mmlu_continuation_stem"
"include": "_continuation_template_yaml"
"task": "mmlu_continuation_college_chemistry"
"dataset_name": "college_computer_science"
"description": "The following are questions (with answers) about college\
\ computer science.\n\n"
"group": "mmlu_continuation_stem"
"include": "_continuation_template_yaml"
"task": "mmlu_continuation_college_computer_science"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment