Unverified Commit c7173ca3 authored by sdtblck's avatar sdtblck Committed by GitHub
Browse files

fix some things in xquad

parent 76aa4c8a
......@@ -2,6 +2,7 @@ from .squad import SQuAD2
from math import exp
from functools import partial
import datasets
from lm_eval.base import rf
def _squad_metric(predictions, references):
......@@ -11,6 +12,9 @@ def _squad_metric(predictions, references):
def _squad_agg(key, items):
    """Aggregate per-doc (prediction, reference) pairs and return one metric.

    :param key: which squad metric to extract (e.g. 'exact_match', 'f1').
    :param items: iterable of (prediction_dict, reference_dict) pairs as
        produced by process_results.
    :return: the scalar value of metric `key` from _squad_metric.
    """
    predictions, references = zip(*items)
    # greedy_until results may arrive wrapped in a single-element list, but
    # the squad metric expects `prediction_text` to be a plain string.
    # Normalize on a copy so the caller's dicts are not mutated in place.
    predictions = [
        {**prediction, 'prediction_text': prediction['prediction_text'][0]}
        if isinstance(prediction['prediction_text'], list)
        else prediction
        for prediction in predictions
    ]
    return _squad_metric(predictions=predictions, references=references)[key]
......@@ -24,6 +28,20 @@ class XQuADBase(SQuAD2):
def has_training_docs(self):
return False
def construct_requests(self, doc, ctx):
    """Build the LM request for a single document.

    :param doc:
        The document as returned from training_docs, validation_docs, or
        test_docs.
    :param ctx: str
        The context string generated by fewshot_context: the natural
        language description, the few-shot examples, and the question
        portion of `doc`.
    :return: a greedy-generation request that stops at the first newline.
    """
    return rf.greedy_until(ctx, ['\n'])
def doc_to_text(self, doc):
text = self.BACKGROUND + '\n\n' + doc['context'] + '\n\n' + self.QUESTION + doc['question'] + '\n\n' + \
......@@ -40,14 +58,11 @@ class XQuADBase(SQuAD2):
:param results:
The results of the requests created in construct_requests.
"""
continuation, (logprob_unanswerable, _) = results
no_answer_probability = exp(logprob_unanswerable)
continuation = results
predictions = {
'id': doc['id'],
'prediction_text': continuation,
'no_answer_probability': no_answer_probability,
}
references = {
......@@ -56,7 +71,7 @@ class XQuADBase(SQuAD2):
}
return {
'exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
'exact_match': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
'f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
}
......@@ -67,7 +82,7 @@ class XQuADBase(SQuAD2):
functions that aggregate a list of metrics
"""
return {
'exact': partial(_squad_agg, 'exact'), # Exact match (the normalized answer exactly match the gold answer)
'exact_match': partial(_squad_agg, 'exact_match'), # Exact match (the normalized answer exactly match the gold answer)
'f1': partial(_squad_agg, 'f1'), # The F-score of predicted tokens versus the gold answer
}
......@@ -78,7 +93,7 @@ class XQuADBase(SQuAD2):
whether a higher value of the submetric is better
"""
return {
'exact': True, # Exact match (the normalized answer exactly match the gold answer)
'exact_match': True, # Exact match (the normalized answer exactly match the gold answer)
'f1': True, # The F-score of predicted tokens versus the gold answer
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment