Unverified Commit a3f10268 authored by sdtblck, committed by GitHub

Update mlqa.py

parent c7173ca3
@@ -170,7 +170,7 @@ def mlqa_metric(predictions, references, answer_lang):
     return evaluate(dataset, pred_dict, answer_lang)


-def mlqa_agg(key, items, answer_lang):
+def mlqa_agg(items, key, answer_lang):
    predictions, references = zip(*items)
    return mlqa_metric(predictions=predictions, references=references, answer_lang=answer_lang)[key]
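The argument order matters because the aggregation hooks further down bind `key` and `answer_lang` by keyword with `functools.partial`, leaving only `items` to arrive positionally. A minimal sketch of why the old order broke (the metric body here is a stand-in, not the real MLQA evaluation):

    from functools import partial

    def mlqa_agg(items, key, answer_lang):
        # items: list of (prediction, reference) pairs collected per document
        predictions, references = zip(*items)
        return {'exact_match': 0.0, 'f1': 0.0}[key]  # stand-in for mlqa_metric(...)[key]

    agg = partial(mlqa_agg, key='exact_match', answer_lang='en')
    agg([('pred', 'ref')])  # items binds to the first positional slot: OK

    # With the old signature, def mlqa_agg(key, items, answer_lang), the same call
    # raises: TypeError: mlqa_agg() got multiple values for argument 'key'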
@@ -218,8 +218,7 @@ class MLQABase(HFTask):
         part of the document for `doc`.
         """
         continuation = rf.greedy_until(ctx, ['\n'])
-        is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
-        return continuation, is_unanswerable
+        return continuation

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -233,12 +232,9 @@ class MLQABase(HFTask):
         """
-        continuation, (logprob_unanswerable, _) = results
-
-        no_answer_probability = exp(logprob_unanswerable)
+        continuation = results[0]

         predictions = {
             'id': doc['id'],
             'prediction_text': continuation,
-            'no_answer_probability': no_answer_probability,
         }

         references = {
@@ -247,7 +243,7 @@ class MLQABase(HFTask):
         }

         return {
-            'exact': (predictions, references),  # Exact match (the normalized answer exactly match the gold answer)
+            'exact_match': (predictions, references),  # Exact match (the normalized answer exactly match the gold answer)
             'f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
         }
@@ -258,7 +254,7 @@ class MLQABase(HFTask):
             functions that aggregate a list of metrics
         """
         return {
-            'exact': partial(mlqa_agg, key='exact', answer_lang=self.ANSWER_LANG),  # Exact match (the normalized
+            'exact_match': partial(mlqa_agg, key='exact_match', answer_lang=self.ANSWER_LANG),  # Exact match (the normalized
                                                                                     # answer exactly match the gold answer)
             'f1': partial(mlqa_agg, key='f1', answer_lang=self.ANSWER_LANG),  # The F-score of predicted tokens
                                                                               # versus the gold answer
@@ -271,7 +267,7 @@ class MLQABase(HFTask):
             whether a higher value of the submetric is better
         """
         return {
-            'exact': True,  # Exact match (the normalized answer exactly match the gold answer)
+            'exact_match': True,  # Exact match (the normalized answer exactly match the gold answer)
             'f1': True,  # The F-score of predicted tokens versus the gold answer
         }
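Note that the metric key has to line up in four places: the dict returned by `process_results`, the `aggregation` hook, the `higher_is_better` hook, and, because `mlqa_agg` ends in `mlqa_metric(...)[key]`, the dict that `evaluate()` returns. A toy consistency check (score values are made up; only the key names come from this diff):

    # The same key set must appear in all three task hooks.
    process_results_out = {'exact_match': (None, None), 'f1': (None, None)}
    aggregation_out = {'exact_match': None, 'f1': None}
    higher_is_better_out = {'exact_match': True, 'f1': True}
    assert set(process_results_out) == set(aggregation_out) == set(higher_is_better_out)

    # And the bound key must exist in the scores dict the evaluate script returns:
    scores = {'exact_match': 71.2, 'f1': 81.5}  # made-up, SQuAD-style keys
    scores['exact_match']  # OK; scores['exact'] would now raise KeyError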