Commit b2b5a122 authored by Leo Gao's avatar Leo Gao
Browse files

Handle aliases correctly in TriviaQA

parent 0601d0bb
......@@ -40,16 +40,30 @@ class TriviaQA(Dataset):
return ''.join(['Q:', doc['Question'], '\n\n','A:'])
def doc_to_target(self, doc):
return doc['Answer']['Aliases'][0]
return " " + doc['Answer']['Value']
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx):
ll, is_prediction = rf.loglikelihood(ctx,doc['Answer']['Value'])
return is_prediction
ret = []
for alias in self._remove_prefixes(doc['Answer']['Aliases']):
_, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction)
return ret
def process_results(self, doc, results):
is_prediction = results
return {
"acc": float(is_prediction[1])
"acc": float(any(results))
}
def aggregation(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment