Commit 960a0e39 authored by jon-tow
Browse files

Update question bundle op to write out options

parent 0ec1b553
...@@ -15,6 +15,7 @@ from lm_eval.metrics import mean ...@@ -15,6 +15,7 @@ from lm_eval.metrics import mean
@dataclass @dataclass
class MultipleChoiceDoc: class MultipleChoiceDoc:
""" Structure for storing documents. """
question: str question: str
keys: typing.List[str] # Should these be the same type as gold? keys: typing.List[str] # Should these be the same type as gold?
options: typing.List[str] options: typing.List[str]
...@@ -27,6 +28,7 @@ class MultipleChoiceDoc: ...@@ -27,6 +28,7 @@ class MultipleChoiceDoc:
class BaseMultipleChoiceTask(base.Task, abc.ABC): class BaseMultipleChoiceTask(base.Task, abc.ABC):
""" Base Multiple Choice Task """
def doc_to_text(self, doc: MultipleChoiceDoc): def doc_to_text(self, doc: MultipleChoiceDoc):
ctx = f"{doc.context}\n" if doc.context else "" ctx = f"{doc.context}\n" if doc.context else ""
...@@ -66,7 +68,7 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC): ...@@ -66,7 +68,7 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC):
# Bundle answers: (model_answer, model_answer_index, is_correct, question_id). # Bundle answers: (model_answer, model_answer_index, is_correct, question_id).
"answer_bundle": (doc.keys[ans], ans, is_correct, doc.id), "answer_bundle": (doc.keys[ans], ans, is_correct, doc.id),
# Bundle questions: (question_id, question, option_0, option_1, option_2, option_3) # Bundle questions: (question_id, question, option_0, option_1, option_2, option_3)
#"question_bundle": (doc.id, doc.question, len(doc.options)), #"question_bundle": (doc.id, doc.question, doc.options),
} }
def higher_is_better(self): def higher_is_better(self):
...@@ -93,7 +95,7 @@ def answer_bundle(items): ...@@ -93,7 +95,7 @@ def answer_bundle(items):
cols = ["model_answer", "model_answer_index", "is_correct", "question_id"] cols = ["model_answer", "model_answer_index", "is_correct", "question_id"]
rows = [*items] rows = [*items]
path = os.environ["QUESTION_RESULT_PATH"] path = os.environ["QUESTION_RESULT_PATH"]
with open(f'{path}/question-by-question-results.csv', 'a') as f: with open(f'{path}/question-by-question-results.csv', 'a', encoding="utf-8") as f:
write = csv.writer(f) write = csv.writer(f)
write.writerow(cols) write.writerow(cols)
write.writerows(rows) write.writerows(rows)
...@@ -104,27 +106,34 @@ def question_bundle(items): ...@@ -104,27 +106,34 @@ def question_bundle(items):
""" Bundles questions into a csv file. """ """ Bundles questions into a csv file. """
from pathlib import Path from pathlib import Path
import csv import csv
num_options = items[0][2] options = items[0][2]
options = [f"option_{i}" for i in range(num_options)] options_name = [f"option_{i}" for i in range(len(options))]
cols = ["question_id","question", *options] cols = ["question_id","question", *options_name]
rows = [*items]
path = os.environ["QUESTION_RESULT_PATH"] path = os.environ["QUESTION_RESULT_PATH"]
with open(f'{path}/question-table.csv', 'a') as f: f = open(f'{path}/question-table.csv', 'a', encoding="utf-8")
write = csv.writer(f) writer = csv.writer(f)
write.writerow(cols) writer.writerow(cols)
write.writerows(rows) for item in items:
writer.writerow([item[0],item[1],*item[2]])
f.close()
return 0 return 0
def key2num(doc: MultipleChoiceDoc, key: str) -> str:
    """ Maps a document key to its 1-based position, rendered as a string.

    Args:
        doc: Document whose `keys` list is searched.
        key: The key to look up; it must be present in `doc.keys`.

    Returns:
        The 1-based index of `key` within `doc.keys` as a string,
        e.g. the first key maps to "1".

    Raises:
        ValueError: If `key` is not found in `doc.keys`.
    """
    # The result is a `str`, not an `int`: the index is wrapped in `str(...)`.
    return str(doc.keys.index(key) + 1)  # `+ 1` for 1-based indexing.
def key2letter(doc: MultipleChoiceDoc, key: str) -> str:
    """ Maps keys to capital alphabet letters.

    The key at position 0 of `doc.keys` maps to "A", position 1 to "B",
    and so on.

    Args:
        doc: Document whose `keys` list is searched.
        key: The key to look up; it must be present in `doc.keys`.

    Returns:
        The single character whose code point is `ord("A")` plus the
        0-based index of `key`. Positions past 25 fall outside "A"-"Z"
        (e.g. position 26 yields "[").

    Raises:
        ValueError: If `key` is not found in `doc.keys`.
    """
    ascii_offset = doc.keys.index(key)
    # `ord("A")` replaces the bare constant 65 so the intent is explicit.
    return chr(ord("A") + ascii_offset)
def format_key(key: str, type: str): def format_key(key: str, type: str):
""" Formats a multiple choice key. E.g. """ Formats a multiple choice key. E.g.
format_key("A", "period") => "A." format_key("A", "period") => "A."
...@@ -152,6 +161,7 @@ class MC_NoOptionList_OptionLL_Task(BaseMultipleChoiceTask): ...@@ -152,6 +161,7 @@ class MC_NoOptionList_OptionLL_Task(BaseMultipleChoiceTask):
Continuation: Continuation:
loglikelihood_continuation = <option_i> loglikelihood_continuation = <option_i>
""" """
def format_prompt(cls, doc: MultipleChoiceDoc) -> str: def format_prompt(cls, doc: MultipleChoiceDoc) -> str:
prompt = "Question: " + doc.question + "\n" prompt = "Question: " + doc.question + "\n"
prompt += "Answer:" prompt += "Answer:"
...@@ -212,7 +222,6 @@ class MC_WithOptionList_LetterLL_Task(BaseMultipleChoiceTask): ...@@ -212,7 +222,6 @@ class MC_WithOptionList_LetterLL_Task(BaseMultipleChoiceTask):
]) ])
prompt += "\nAnswer:" prompt += "\nAnswer:"
return prompt return prompt
def doc_to_target(self, doc: MultipleChoiceDoc) -> str: def doc_to_target(self, doc: MultipleChoiceDoc) -> str:
return " " + doc.keys[doc.gold] return " " + doc.keys[doc.gold]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment