Commit d7c08d5b authored by Leo Gao's avatar Leo Gao
Browse files

Convert SAT to use MultipleChoiceTask

parent 2f5f42c6
...@@ -61,8 +61,8 @@ class SATAnalogies(Task): ...@@ -61,8 +61,8 @@ class SATAnalogies(Task):
doc = { doc = {
'source': source, 'source': source,
'query': query.split(' ')[:2], 'query': query.split(' ')[:2],
'choices': [c.split(' ')[:2] for c in choices], 'choices': [" {} is to {}".format(*c.split(' ')[:2]) for c in choices],
'answer_key': ['a','b','c','d','e'].index(answer_key.strip()), 'gold': ['a','b','c','d','e'].index(answer_key.strip()),
} }
yield doc yield doc
...@@ -72,35 +72,4 @@ class SATAnalogies(Task): ...@@ -72,35 +72,4 @@ class SATAnalogies(Task):
return "" return ""
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{} is to {} as ".format(*doc['query']) return "{} is to {} as".format(*doc['query'])
def doc_to_target(self, doc):
return "{} is to {}".format(*doc['choices'][doc['answer_key']])
def construct_requests(self, doc, ctx):
lls = [
rf.loglikelihood(ctx, "{} is to {}".format(*doc['choices'][i]))[0]
for i in range(5)
]
return lls
def process_results(self, doc, results):
gold = doc["answer_key"]
acc = 1. if np.argmax(results) == gold else 0.
return {
"acc": acc
}
def higher_is_better(self):
return {
"acc": True
}
def aggregation(self):
return {
"acc": mean
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment