Unverified commit fe3040f1, authored by am-bean, committed by GitHub
Browse files

LingOly - Fixing scoring bugs for smaller models (#2376)

* Fixing scoring bugs for smaller models

* Catching another error type in parsing
parent 8f619361
# Task-name
LingOly
# LingOly
### Paper
......@@ -27,21 +26,11 @@ Homepage: `https://github.com/am-bean/lingOly`
}
```
### Groups, Tags, and Tasks
### Tasks
#### Groups
* `group_name`: `Short description`
#### Tags
* `reasoning`: ``
* `linguistics`: ``
#### Tasks
* `exact_match`: `exact match of generations to reference`
* `delta_nc`: `improvement in score relative to no-context baseline`
* `lingoly`: `runs both _context and _nocontext and computes the difference`
* `lingoly_context`: `exact match of generations to reference answers`
* `lingoly_nocontext`: `exact match of generations to reference answers, but with context removed`
### Checklist
......
......@@ -9,6 +9,13 @@ validation_split: test
test_split: test
fewshot_split: null
generation_kwargs:
until:
- "}\n"
max_gen_toks: 512
do_sample: false
temperature: 0.0
process_docs: !function utils.load_all_questions
doc_to_text: prompt
......
......@@ -9,6 +9,13 @@ validation_split: test
test_split: test
fewshot_split: null
generation_kwargs:
until:
- "}\n"
max_gen_toks: 512
do_sample: false
temperature: 0.0
process_docs: !function utils.load_all_questions
doc_to_text: nc_prompt
......
......@@ -45,13 +45,10 @@ def parse_str_list_score(model, correct, scoring_func):
return 1.0
if len(model) == 0:
return 0.0
if "[" in correct:
try:
readstr = ast.literal_eval(correct)
if isinstance(readstr, list):
correct = readstr
except SyntaxError:
pass
if ("[" in correct) and (("'" in correct) or ('"' in correct)):
readstr = ast.literal_eval(correct)
if isinstance(readstr, list):
correct = readstr
if isinstance(correct, list):
if all(isinstance(c, str) for c in correct):
max_score = 0.0
......@@ -91,21 +88,31 @@ def parse_str_list_score(model, correct, scoring_func):
)
def exact_match(input):
ref_dict = ast.literal_eval(input[0])
def exact_match(references: list[str], predictions: list[str]):
ref_dict = ast.literal_eval(references[0])
try:
pred_dict = ast.literal_eval(input[1])
except SyntaxError:
assert "{" in predictions[0]
if predictions[0][-1] == "}":
pred_dict = ast.literal_eval(predictions[0][predictions[0].index("{") :])
else:
pred_dict = ast.literal_eval(
predictions[0][predictions[0].index("{") :] + "}"
)
except (SyntaxError, ValueError, AssertionError):
pred_dict = {}
for k in ref_dict.keys():
m = re.search(str(k) + "': ([^']+)'[,\\}]", input[1])
m = re.search(re.escape(str(k)) + """': ([^']+)'[,\\}]""", predictions[0])
n = re.search(re.escape(str(k)) + """": ([^"]+)"[,\\}]""", predictions[0])
if m:
pred_dict[k] = m.group()[:-1]
elif n:
pred_dict[k] = n.group()[:-1]
else:
pred_dict[k] = ""
pred_dict_full = {
k: pred_dict[k] if k in pred_dict else "" for k in ref_dict.keys()
}
scores = [
parse_str_list_score(pred_dict_full[k], v, safe_exact)
for k, v in ref_dict.items()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment