Commit 478d96cd authored by Baber's avatar Baber
Browse files

test out gsm8k

parent 2275b1c4
...@@ -44,7 +44,10 @@ class JudgeTask(ConfigurableTask): ...@@ -44,7 +44,10 @@ class JudgeTask(ConfigurableTask):
self.dataset["test"] = self.dataset["test"].add_column( self.dataset["test"] = self.dataset["test"].add_column(
"resp", [resp["resp"] for resp in resps] "resp", [resp["resp"] for resp in resps]
) )
print("done") self.dataset["train"] = self.dataset["train"].add_column(
"resp", self.dataset["train"]["answer"]
)
print("resp columns added")
# def process_docs(self, dataset: datasets.Dataset): # def process_docs(self, dataset: datasets.Dataset):
# resps = [] # resps = []
......
...@@ -6,8 +6,10 @@ output_type: generate_until ...@@ -6,8 +6,10 @@ output_type: generate_until
output_path: output_path:
doc_to_text: "Question: {{question}}\nAnswer:" #doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}" doc_to_text: 'Given the following question and reference answer, verify if the attempted answer is correct. If it is, return "The answer is Correct". If it is incorrect, return "The answer is Incorrect".\nQuestion: {{question}}\nReference Answer: {{answer}}\nAnswer Attempt: {{resp}}'
target_delimiter: "\n"
doc_to_target: "The answer is Correct" #" {{answer.split('### ')[-1].rstrip()}}"
metric_list: metric_list:
- metric: exact_match - metric: exact_match
aggregation: mean aggregation: mean
...@@ -21,6 +23,7 @@ metric_list: ...@@ -21,6 +23,7 @@ metric_list:
- "\\.$" - "\\.$"
generation_kwargs: generation_kwargs:
until: until:
- '<|start_header_id|>user<|end_header_id|>'
- "Question:" - "Question:"
- "</s>" - "</s>"
- "<|im_end|>" - "<|im_end|>"
...@@ -29,16 +32,12 @@ generation_kwargs: ...@@ -29,16 +32,12 @@ generation_kwargs:
repeats: 1 repeats: 1
num_fewshot: 5 num_fewshot: 5
filter_list: filter_list:
- name: "strict-match" - name: "test"
filter: filter:
- function: "regex" - function: "regex"
regex_pattern: "#### (\\-?[0-9\\.\\,]+)" regex_pattern: "The answer is (Correct|Incorrect)"
- function: "take_first" ignore_punctuation: true
- name: "flexible-extract" ignore_case: true
filter:
- function: "regex"
group_select: -1
regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
- function: "take_first" - function: "take_first"
metadata: metadata:
version: 3.0 version: 3.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment