Unverified Commit 5ac7cdf8 authored by Janna's avatar Janna Committed by GitHub
Browse files

Support for AIME dataset (#3248)

* add AIME tasks

* standardize the repeats

* fix task naming

* aime25 only has test set

* edit readme

* add utils

* standardize

* fix case sensitivity

* repeat once

* lint

* more linting

* lint huggingface.py
parent 0b45cc71
......@@ -682,11 +682,13 @@ class HFLM(TemplateLM):
raise AssertionError("load_in_4bit requires peft >= 0.4.0")
# Compatible with Gemma3 (multimodal) and old models
if hasattr(self._model.config, "text_config") and hasattr(self._model.config.text_config, "vocab_size"):
if hasattr(self._model.config, "text_config") and hasattr(
self._model.config.text_config, "vocab_size"
):
vocab_size = self._model.config.text_config.vocab_size
else:
vocab_size = self._model.config.vocab_size
if vocab_size != len(self.tokenizer):
# resize model for LoRAs with added tokens
eval_logger.info(
......
# AIME
### Citation
```text
@dataset{aime_1983_2024,
author = {Hemish Veeraboina},
title = {AIME Problem Set 1983-2024},
year = {2024},
publisher = {Kaggle},
url = {https://www.kaggle.com/datasets/hemishveeraboina/aime-problem-set-1983-2024}
}
@dataset{aime_2024,
author = {Maxwell Jia},
title = {AIME Problem Set 2024},
year = {2024},
publisher = {Huggingface},
url = {https://huggingface.co/datasets/Maxwell-Jia/AIME_2024}
}
@dataset{aime_2025,
author = {math-ai},
title = {AIME Problem Set 2025},
year = {2025},
publisher = {Huggingface},
url = {https://huggingface.co/datasets/math-ai/aime25}
}
```
### Groups, Tags, and Tasks
#### Groups
* `math_word_problems`
#### Tasks
* `aime`: `AIME 1983-2024 problems`
* `aime24`: `AIME 2024 problems`
* `aime25`: `AIME 2025 problems`
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
tag:
- math_word_problems
task: aime
dataset_path: gneubig/aime-1983-2024
# dataset_name: null
output_type: generate_until
training_split: train
fewshot_split: train
test_split: train
doc_to_text: "Question: {{Question}}\nAnswer:"
doc_to_target: "{{Answer}}"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
generation_kwargs:
until:
- "Question:"
- "</s>"
- "<|im_end|>"
- "<|eot_id|>"
do_sample: false
temperature: 0.0
max_gen_toks: 32768
repeats: 1
num_fewshot: 0
metadata:
version: 0.0
tag:
- math_word_problems
task: aime24
dataset_path: Maxwell-Jia/AIME_2024
# dataset_name: null
output_type: generate_until
training_split: train
fewshot_split: train
test_split: train
doc_to_text: "Question: {{Problem}}\nAnswer:"
doc_to_target: "{{Answer}}"
process_results: !function utils.process_results
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
generation_kwargs:
until:
- "Question:"
- "</s>"
- "<|im_end|>"
- "<|eot_id|>"
do_sample: false
temperature: 0.0
max_gen_toks: 32768
repeats: 1
num_fewshot: 0
metadata:
version: 0.0
tag:
- math_word_problems
task: aime25
dataset_path: math-ai/aime25
# dataset_name: null
output_type: generate_until
training_split: test
fewshot_split: test
test_split: test
doc_to_text: "Question: {{problem}}\nAnswer:"
doc_to_target: "{{answer}}"
process_results: !function utils.process_results
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
generation_kwargs:
until:
- "Question:"
- "</s>"
- "<|im_end|>"
- "<|eot_id|>"
do_sample: false
temperature: 0.0
max_gen_toks: 32768
repeats: 1
num_fewshot: 0
metadata:
version: 0.0
import re
from typing import Dict, List
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
retval = 0
response = results[0]
# Try to extract answer from $...$ format first
indices = [pos for pos, char in enumerate(response) if char == "$"]
if len(indices) <= 1:
answer = response
else:
answer = response[indices[0] + 1 : indices[-1]]
# Extract from \\boxed{} if present
boxed_answer = last_boxed_only_string(response)
if boxed_answer is not None:
try:
boxed_content = remove_boxed(boxed_answer)
if boxed_content is not None:
answer = boxed_content
except (AssertionError, IndexError):
pass
# Check if answer matches target
answer_key = next(k for k in doc.keys() if k.lower() == "answer")
target = str(doc[answer_key])
if is_equiv(answer, target):
retval = 1
return {"exact_match": retval}
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
if str1 is None and str2 is None:
print("WARNING: Both None")
return True
if str1 is None or str2 is None:
return False
try:
ss1 = strip_string(str1)
ss2 = strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except Exception:
return str1 == str2
def remove_boxed(s):
if "\\boxed " in s:
left = "\\boxed "
assert s[: len(left)] == left
return s[len(left) :]
left = "\\boxed{"
assert s[: len(left)] == left
assert s[-1] == "}"
return s[len(left) : -1]
def last_boxed_only_string(string):
idx = string.rfind("\\boxed")
if "\\boxed " in string:
return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx : right_brace_idx + 1]
return retval
def fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except AssertionError:
return string
def remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def strip_string(string):
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "") # noqa: W605
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = fix_a_slash_b(string)
return string
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment