Commit 30430b60 authored by Baber's avatar Baber
Browse files

prototype JudgeTask

parent 5d43bb4b
import json
import datasets
from lm_eval.api.task import ConfigurableTask
class JudgeTask(ConfigurableTask):
def __init__(self, config, output_path):
super().__init__(config)
self.output_path = output_path
def process_docs(self, dataset: datasets.Dataset):
resps = []
# load json
with open(self.output_path, "r") as f:
resp = json.load(f)
resps.append({"resp": resp["resps"], "doc": resp["doc"]})
resps.sort(key=lambda x: x["doc"])
dataset.add_column("resp", resps)
return resps
......@@ -7,6 +7,7 @@ from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils
from lm_eval.api.group import ConfigurableGroup, GroupConfig
from lm_eval.api.judge_task import JudgeTask
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.evaluator_utils import get_subtask_list
......@@ -272,6 +273,11 @@ class TaskManager:
if isinstance(task_object, ConfigurableTask):
# very scuffed: set task name here. TODO: fixme?
task_object.config.task = config["task"]
else:
if "resp_to_doc" in config:
task_object = JudgeTask(
config=config, output_path=config.get("output_path", None)
)
else:
task_object = ConfigurableTask(config=config)
......
tag:
- judge
include: gsm8k_judge_1.yaml
task: gsm8k_judge_2
output_type: generate_until
output_path:
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: false
regexes_to_ignore:
- ","
- "\\$"
- "(?s).*#### "
- "\\.$"
generation_kwargs:
until:
- "Question:"
- "</s>"
- "<|im_end|>"
do_sample: false
temperature: 0.0
repeats: 1
num_fewshot: 5
filter_list:
- name: "strict-match"
filter:
- function: "regex"
regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
- function: "take_first"
- name: "flexible-extract"
filter:
- function: "regex"
group_select: -1
regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
- function: "take_first"
metadata:
version: 3.0
tag:
- math_word_problems
task: gsm8k
dataset_path: gsm8k
dataset_name: main
output_type: generate_until
training_split: train
fewshot_split: train
test_split: test
doc_to_text: "Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem\n\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
metric_list:
- metric: bypass
aggregation: mean
generation_kwargs:
until:
- "Question:"
- "</s>"
- "<|im_end|>"
do_sample: false
temperature: 0.0
repeats: 1
num_fewshot: 5
#filter_list:
# - name: "strict-match"
# filter:
# - function: "regex"
# regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
# - function: "take_first"
# - name: "flexible-extract"
# filter:
# - function: "regex"
# group_select: -1
# regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
# - function: "take_first"
metadata:
version: 3.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment