Commit c9a3fd3f authored by lintangsutawika's avatar lintangsutawika
Browse files

update squad file

parent 5a4fc8fb
...@@ -16,7 +16,7 @@ from lm_eval.api.registry import ( ...@@ -16,7 +16,7 @@ from lm_eval.api.registry import (
import logging import logging
# import python tasks # import python tasks
from .squad import SQuAD2 from .squadv2.task import SQuAD2
from .scrolls.task import ( from .scrolls.task import (
QuALITY, QuALITY,
NarrativeQA, NarrativeQA,
......
...@@ -14,6 +14,7 @@ also determine when no answer is supported by the paragraph and abstain from ans ...@@ -14,6 +14,7 @@ also determine when no answer is supported by the paragraph and abstain from ans
Homepage: https://rajpurkar.github.io/SQuAD-explorer/ Homepage: https://rajpurkar.github.io/SQuAD-explorer/
""" """
import datasets import datasets
from evaluate import load
from math import exp from math import exp
from functools import partial from functools import partial
...@@ -36,6 +37,7 @@ _CITATION = """ ...@@ -36,6 +37,7 @@ _CITATION = """
def _squad_metric(predictions, references): def _squad_metric(predictions, references):
# squad_metric = load("squad_v2")
squad_metric = datasets.load_metric("squad_v2") squad_metric = datasets.load_metric("squad_v2")
return squad_metric.compute(predictions=predictions, references=references) return squad_metric.compute(predictions=predictions, references=references)
...@@ -46,6 +48,20 @@ def _squad_agg(key, items): ...@@ -46,6 +48,20 @@ def _squad_agg(key, items):
return _squad_metric(predictions=predictions, references=references).get(key, 0) return _squad_metric(predictions=predictions, references=references).get(key, 0)
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
return re.sub(regex, ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
@register_task("squadv2") @register_task("squadv2")
class SQuAD2(Task): class SQuAD2(Task):
VERSION = 1 VERSION = 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment