Commit b2b5a122 authored by Leo Gao
Browse files

Handle aliases correctly in TriviaQA

parent 0601d0bb
...@@ -40,16 +40,30 @@ class TriviaQA(Dataset): ...@@ -40,16 +40,30 @@ class TriviaQA(Dataset):
return ''.join(['Q:', doc['Question'], '\n\n','A:']) return ''.join(['Q:', doc['Question'], '\n\n','A:'])
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc['Answer']['Aliases'][0] return " " + doc['Answer']['Value']
def _remove_prefixes(self, aliases):
    """Return the aliases with every entry that extends another entry removed.

    Optimization: an alias that starts with another alias in the list
    (including an exact duplicate) is dropped. Per the original rationale,
    if the shorter prefix is already accepted by the is-greedy check there
    is no need to also score the longer alias.

    Args:
        aliases: list of alias strings. The list is NOT modified; a sorted
            copy is used internally.

    Returns:
        A new list, in sorted order, in which no element starts with
        another element. An empty input yields an empty list.
    """
    # Guard: indexing the first element of an empty list would raise IndexError.
    if not aliases:
        return []

    # Sort a copy rather than in place so the caller's list keeps its order.
    # After sorting, every string that extends a given prefix sits directly
    # behind it, so comparing against the last kept alias is sufficient.
    ordered = sorted(aliases)
    kept = [ordered[0]]
    for alias in ordered[1:]:
        if not alias.startswith(kept[-1]):
            kept.append(alias)
    return kept
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll, is_prediction = rf.loglikelihood(ctx,doc['Answer']['Value']) ret = []
return is_prediction for alias in self._remove_prefixes(doc['Answer']['Aliases']):
_, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction)
return ret
def process_results(self, doc, results): def process_results(self, doc, results):
is_prediction = results
return { return {
"acc": float(is_prediction[1]) "acc": float(any(results))
} }
def aggregation(self): def aggregation(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment