Commit 9f3655ac authored by Baber

add math_verify

parent ccf4a58a
...@@ -2,6 +2,7 @@ dataset_name: main ...@@ -2,6 +2,7 @@ dataset_name: main
dataset_path: gsm8k dataset_path: gsm8k
doc_to_target: '{{answer.split(''####'')[-1].strip() if answer is defined else target}}' doc_to_target: '{{answer.split(''####'')[-1].strip() if answer is defined else target}}'
doc_to_text: "Given the following problem, reason and give a final answer to the problem.\nProblem: {{question}}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n" doc_to_text: "Given the following problem, reason and give a final answer to the problem.\nProblem: {{question}}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n"
process_results: !function utils.process_results
fewshot_config: fewshot_config:
sampler: first_n sampler: first_n
samples: samples:
...@@ -41,19 +42,6 @@ fewshot_config: ...@@ -41,19 +42,6 @@ fewshot_config:
she have left? she have left?
target: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 target: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15
dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The final answer is 8 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The final answer is 8
filter_list:
- filter:
- function: regex
group_select: -1
regex_pattern: The final answer is ((-?[$0-9.,]{2,})|(-?[0-9]+))
- function: take_first
name: strict-match
- filter:
- function: regex
group_select: -1
regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+)
- function: take_first
name: flexible-extract
generation_kwargs: generation_kwargs:
do_sample: false do_sample: false
until: until:
......
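For context on what the removed `filter_list` entries did, here is a minimal sketch of the two regex filters from the hunk above. It assumes the regex filter behaves like `re.findall`, with `group_select` indexing the list of matches and the first non-empty capture group being kept; the helper name `extract` and the `"[invalid]"` fallback are illustrative assumptions, not taken from this commit.

```python
import re

# Patterns copied from the removed strict-match / flexible-extract filters.
STRICT = re.compile(r"The final answer is ((-?[$0-9.,]{2,})|(-?[0-9]+))")
FLEXIBLE = re.compile(r"(-?[$0-9.,]{2,})|(-?[0-9]+)")


def extract(pattern, response, group_select=-1, fallback="[invalid]"):
    """Hypothetical helper: pick one match from the response, findall-style."""
    matches = pattern.findall(response)
    if not matches:
        return fallback
    match = matches[group_select]  # -1 selects the last match in the response
    if isinstance(match, tuple):
        # With several capture groups, findall yields tuples; keep the first
        # non-empty group.
        non_empty = [group for group in match if group]
        match = non_empty[0] if non_empty else fallback
    return match.strip()


response = "Olivia had 23 dollars. 5 x 3 = 15. The final answer is 8"
print(extract(STRICT, response))
print(extract(FLEXIBLE, response))
print(extract(STRICT, "I think it is 8"))
```

Under these assumptions the strict pattern only fires when the response contains the prompted phrase, while the flexible pattern falls back to the last number-like token anywhere in the response.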
...@@ -7,19 +7,13 @@ output_type: generate_until ...@@ -7,19 +7,13 @@ output_type: generate_until
training_split: train training_split: train
fewshot_split: train fewshot_split: train
test_split: test test_split: test
process_results: !function utils.process_results
doc_to_text: "Q: {{question}}\nA: Let's think step by step." doc_to_text: "Q: {{question}}\nA: Let's think step by step."
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}" doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
metric_list: metric_list:
- metric: exact_match - metric: math_verify
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
ignore_case: true
ignore_punctuation: false
regexes_to_ignore:
- ","
- "\\$"
- "(?s).*#### "
- "\\.$"
generation_kwargs: generation_kwargs:
until: until:
- "Q:" - "Q:"
...@@ -28,17 +22,5 @@ generation_kwargs: ...@@ -28,17 +22,5 @@ generation_kwargs:
do_sample: false do_sample: false
repeats: 1 repeats: 1
num_fewshot: 0 num_fewshot: 0
filter_list:
- name: "strict-match"
filter:
- function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)."
- function: "take_first"
- name: "flexible-extract"
filter:
- function: "regex"
group_select: -1
regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
- function: "take_first"
metadata: metadata:
version: 3.0 version: 3.0
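The removed `exact_match` settings can be read as a string-normalization step followed by a plain equality check. The sketch below assumes each entry in `regexes_to_ignore` is applied in order with `re.sub` to both the prediction and the reference, and that `ignore_case` lowercases both sides; the function names are illustrative, not the harness API.

```python
import re

# Patterns copied from the removed regexes_to_ignore list, in config order.
REGEXES_TO_IGNORE = [",", r"\$", r"(?s).*#### ", r"\.$"]


def normalize(text, regexes=REGEXES_TO_IGNORE, ignore_case=True):
    """Strip every ignored pattern, then optionally lowercase."""
    for pattern in regexes:
        text = re.sub(pattern, "", text)
    return text.lower() if ignore_case else text


def exact_match(prediction, reference):
    """Return 1 when the normalized strings are identical, else 0."""
    return int(normalize(prediction) == normalize(reference))


reference = "Natalia sold 48/2 = 24 clips in May.\n#### 72"
print(exact_match("$72.", reference))
print(exact_match("72.0", reference))
```

Because the comparison is purely textual, `72.0` and `72` count as different answers in this sketch even though they are numerically equal.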
...@@ -43,18 +43,6 @@ fewshot_config: ...@@ -43,18 +43,6 @@ fewshot_config:
she have left? she have left?
target: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 target: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15
dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8. dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.
filter_list:
- filter:
- function: regex
regex_pattern: The answer is (\-?[0-9\.\,]+).
- function: take_first
name: strict-match
- filter:
- function: regex
group_select: -1
regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+)
- function: take_first
name: flexible-extract
generation_kwargs: generation_kwargs:
do_sample: false do_sample: false
until: until:
...@@ -66,16 +54,9 @@ tag: ...@@ -66,16 +54,9 @@ tag:
metadata: metadata:
version: 3.0 version: 3.0
metric_list: metric_list:
- aggregation: mean - aggregation: math_verify
higher_is_better: true higher_is_better: true
ignore_case: true ignore_case: true
ignore_punctuation: false
metric: exact_match
regexes_to_ignore:
- ','
- \$
- '(?s).*#### '
- \.$
num_fewshot: 8 num_fewshot: 8
output_type: generate_until output_type: generate_until
repeats: 1 repeats: 1
......
...@@ -7,19 +7,13 @@ output_type: generate_until ...@@ -7,19 +7,13 @@ output_type: generate_until
training_split: train training_split: train
fewshot_split: train fewshot_split: train
test_split: test test_split: test
process_results: !function utils.process_results
doc_to_text: "Question: {{question}}\nAnswer:" doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}" doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
metric_list: metric_list:
- metric: exact_match - metric: math_verify
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
ignore_case: true
ignore_punctuation: false
regexes_to_ignore:
- ","
- "\\$"
- "(?s).*#### "
- "\\.$"
generation_kwargs: generation_kwargs:
until: until:
- "Question:" - "Question:"
...@@ -29,17 +23,5 @@ generation_kwargs: ...@@ -29,17 +23,5 @@ generation_kwargs:
temperature: 0.0 temperature: 0.0
repeats: 1 repeats: 1
num_fewshot: 5 num_fewshot: 5
filter_list:
- name: "strict-match"
filter:
- function: "regex"
regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
- function: "take_first"
- name: "flexible-extract"
filter:
- function: "regex"
group_select: -1
regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
- function: "take_first"
metadata: metadata:
version: 3.0 version: 3.0
from typing import Dict, List

from math_verify import parse, verify


def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
    candidates = results[0]

    # math_verify
    res = verify(parse(doc["answer"]), parse(candidates))
    mathval = 1 if res else 0

    results = {
        "math_verify": mathval,
    }
    return results
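To show how a per-document score from `process_results` could feed the `aggregation: mean` setting, here is a small sketch that avoids the `math_verify` dependency. The comparator `numeric_equal` is a hypothetical stand-in that only compares the last whitespace-separated token as a float; it does not reproduce what `math_verify.parse` and `math_verify.verify` do, and `make_process_results` is likewise an illustrative name.

```python
from statistics import mean
from typing import Callable, Dict, List


def numeric_equal(gold: str, candidate: str) -> bool:
    """Hypothetical stand-in comparator: compare the last token as a float."""
    try:
        return float(gold.split()[-1]) == float(candidate.split()[-1])
    except (ValueError, IndexError):
        return False


def make_process_results(is_equivalent: Callable[[str, str], bool]):
    """Build a process_results function around any equivalence check."""

    def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
        candidate = results[0]  # first (and only) generation for this doc
        score = 1 if is_equivalent(doc["answer"], candidate) else 0
        return {"math_verify": score}

    return process_results


process_results = make_process_results(numeric_equal)
docs = [{"answer": "#### 8"}, {"answer": "#### 72"}]
responses = ["The final answer is 8.0", "The final answer is 70"]

# One 0/1 score per document, then the mean across documents.
scores = [
    process_results(doc, [resp])["math_verify"]
    for doc, resp in zip(docs, responses)
]
print(mean(scores))
```

Passing the comparator in as an argument keeps the scoring shape (one 0/1 value under the `math_verify` key) separate from the choice of equivalence check.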
...@@ -3,6 +3,7 @@ import signal ...@@ -3,6 +3,7 @@ import signal
from typing import Dict, List, Optional from typing import Dict, List, Optional
import datasets import datasets
from math_verify import parse, verify
from lm_eval.utils import eval_logger from lm_eval.utils import eval_logger
...@@ -75,8 +76,13 @@ def process_results(doc: dict, results: List[str]) -> Dict[str, int]: ...@@ -75,8 +76,13 @@ def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
else: else:
retval = 0 retval = 0
# math_verify
res = verify(parse(doc["answer"]), parse(candidates))
mathval = 1 if res else 0
results = { results = {
"exact_match": retval, "exact_match": retval,
"math_verify": mathval,
} }
return results return results
......