Unverified commit 1e44199c, authored by Stella Biderman, committed by GitHub

Merge branch 'master' into add_lama

parents 235494f4 5e59320b
@@ -353,11 +353,13 @@ class BaseLM(LM):
        for context, request_args in tqdm(reord.get_reordered()):
            stopping_criteria = request_args["stopping_criteria"]
            max_generation_length = request_args["max_generation_length"]
+           num_fewshot = request_args["num_fewshot"]
            assert isinstance(stopping_criteria, str) or stopping_criteria is None
            assert (
                isinstance(max_generation_length, int) or max_generation_length is None
            )
+           assert isinstance(num_fewshot, int) or num_fewshot is None
            if stopping_criteria is None:
                until = [self.eot_token]
@@ -382,9 +384,10 @@ class BaseLM(LM):
                context_enc,
                max_length,
                torch.tensor(primary_until),
+               num_fewshot,
            )
-           s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
+           s = self.tok_decode(cont.tolist())
            for term in until:
                s = s.split(term)[0]
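The last two lines above implement stop-sequence truncation; a minimal standalone sketch with hypothetical strings:

    s = "The answer is 4.\n###\nQ: next question"
    until = ["\n###\n", "<|endoftext|>"]
    for term in until:
        s = s.split(term)[0]  # keep only the text before each stop sequence
    # s == "The answer is 4."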
@@ -536,7 +539,7 @@ class Task(abc.ABC):
        pass

    @abstractmethod
-   def construct_requests(self, doc, ctx):
+   def construct_requests(self, doc, ctx, args):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.
@@ -546,6 +549,8 @@ class Task(abc.ABC):
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
+       :param args: dict
+           The specifics of the context, including the number of few-shot examples.
        """
        pass
@@ -645,14 +650,18 @@ class Task(abc.ABC):
            # get rid of the doc that's the one we're evaluating, if it's in the fewshot
            fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

+           # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
+           # for justification of this separator.
+           example_separator = "\n###\n"
            labeled_examples = (
-               "\n\n".join(
+               example_separator.join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
-               + "\n\n"
+               + example_separator
            )

        example = self.doc_to_text(doc)
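A minimal sketch of the few-shot context this now produces (hypothetical one-shot data):

    fewshot = ["Q: Is the sky blue? A: yes"]
    question = "Q: Is grass red? A:"
    example_separator = "\n###\n"
    ctx = example_separator.join(fewshot) + example_separator + question
    # "Q: Is the sky blue? A: yes\n###\nQ: Is grass red? A:"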
@@ -720,7 +729,7 @@ class PromptSourceTask(Task):
        text, _ = self.prompt.apply(doc)
        return text

-   def construct_requests(self, doc, ctx):
+   def construct_requests(self, doc, ctx, args):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.
@@ -730,6 +739,8 @@ class PromptSourceTask(Task):
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
+       :param args: dict
+           The specifics of the context, including the number of few-shot examples.
        """
        _requests = []
        answer_choices_list = self.prompt.get_answer_choices_list(doc)
@@ -745,6 +756,7 @@ class PromptSourceTask(Task):
            request_args = {
                "stopping_criteria": self.stopping_criteria(),
                "max_generation_length": self.max_generation_length(),
+               "num_fewshot": args["num_fewshot"],
            }
            cont_request = rf.greedy_until(ctx, request_args)
            _requests.append(cont_request)
@@ -911,12 +923,12 @@ class PromptSourceTask(Task):
        if num_fewshot == 0:
            labeled_examples = ""
-           fewshotex, fewshotidx, fewshotsource = [], [], None
+           fewshotex, fewshotidx, self.fewshotsource = [], [], None
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-               fewshotsource = "train"
+               self.fewshotsource = "train"
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
@@ -925,32 +937,35 @@ class PromptSourceTask(Task):
                        else self.test_docs()
                    )
                if self.has_validation_docs():
-                   fewshotsource = "val"
+                   self.fewshotsource = "val"
                elif self.test_docs():
-                   fewshotsource = "test"
+                   self.fewshotsource = "test"
                fewshotex, fewshotidx = self._get_fewshot_examples(
                    self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
                )
-               fewshotex, fewshotidx = [
+               fewshotex, fewshotidx = zip(*[
                    (shot, idx)
                    for shot, idx in zip(fewshotex, fewshotidx)
                    if shot != doc
-               ]
+               ])
            # get rid of the doc that's the one we're evaluating, if it's in the fewshot
            fewshotex, fewshotidx = (
                fewshotex[:num_fewshot],
                fewshotidx[:num_fewshot],
            )
+           # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
+           # for justification of this separator.
+           example_separator = "\n###\n"
            labeled_examples = (
-               "\n\n".join(
+               example_separator.join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
-               + "\n\n"
+               + example_separator
            )

        example = self.doc_to_text(doc)
@@ -959,7 +974,7 @@ class PromptSourceTask(Task):
            ctx,
            {
                "fewshot_idx": fewshotidx,
-               "fewshot_source": fewshotsource,
+               "fewshot_source": self.fewshotsource,
                "fewshot_num": num_fewshot,
                "ctx": ctx,
            },
...
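The `zip(*...)` change above unzips the filtered `(shot, idx)` pairs back into two parallel sequences; a standalone sketch of the idiom, with one caveat worth noting: if the comprehension filters out every pair, unpacking `zip(*[])` raises a ValueError.

    pairs = [("example A", 0), ("example B", 1), ("doc", 2)]
    shots, idxs = zip(*[(s, i) for s, i in pairs if s != "doc"])
    # shots == ("example A", "example B"); idxs == (0, 1); both tuples, not lists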
@@ -206,7 +206,8 @@ def evaluate(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            fewshotex_logging_info["doc_id"] = original_doc_id
-           reqs = task.construct_requests(doc, ctx)
+           args = {"num_fewshot": num_fewshot}
+           reqs = task.construct_requests(doc, ctx, args)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
...
@@ -12,6 +12,7 @@ class HFLM(BaseLM):
        subfolder=None,
        tokenizer=None,
        batch_size=1,
+       parallelize=False
    ):
        super().__init__()
@@ -32,7 +33,7 @@ class HFLM(BaseLM):
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision + ("/" + subfolder if subfolder is not None else ""),
-       ).to(self.device)
+       )
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
@@ -68,9 +69,11 @@ class HFLM(BaseLM):
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # TODO: fix multi-gpu
-       # gpus = torch.cuda.device_count()
-       # if gpus > 1:
-       #     self.gpt2 = nn.DataParallel(self.gpt2)
+       if parallelize:
+           self.gpt2.parallelize()
+           self._device = torch.device('cuda:0')
+       else:
+           self.gpt2.to(self._device)

    @property
    def eot_token(self):
@@ -146,16 +149,26 @@ class HFLM(BaseLM):
                EOSCriteria(self.tokenizer.eos_token)
            ])

-   def _model_generate(self, context, max_length, stopping_criteria_ids):
+   def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-       return self.gpt2.generate(
-           context,
-           max_length=max_length,
-           stopping_criteria=stopping_criteria,
-           do_sample=False,
-       )
+       if num_fewshot == 0:
+           generations = self.gpt2.generate(
+               context,
+               max_length=max_length,
+               eos_token_id=self.eot_token_id,
+               do_sample=False,
+           )
+       else:
+           generations = self.gpt2.generate(
+               context,
+               max_length=max_length,
+               stopping_criteria=stopping_criteria,
+               do_sample=False,
+           )
+       # Remove the context from the generations
+       return generations[0, context.shape[1]:]
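The final slice strips the prompt tokens from the returned sequence; a standalone sketch with hypothetical shapes (note that indexing row 0 assumes a batch of one):

    import torch

    context = torch.ones((1, 5), dtype=torch.long)      # 5 prompt tokens
    generations = torch.arange(12).reshape(1, 12)       # prompt + 7 new tokens
    continuation = generations[0, context.shape[1]:]    # just the 7 generated tokens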
# for backwards compatibility
GPT2LM = HFLM
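A minimal usage sketch of the new flag, assuming a multi-GPU host and a checkpoint whose model class supports Hugging Face's model.parallelize() API (the checkpoint name here is illustrative):

    lm = HFLM(pretrained="EleutherAI/gpt-j-6B", batch_size=1, parallelize=True)
    # With parallelize=True the model is sharded across the visible GPUs and
    # inputs are placed on cuda:0; with the default False it moves to self._device.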
@@ -39,6 +39,10 @@ class GPTJLM(BaseLM):
        # if gpus > 1:
        #     self.gptj = nn.DataParallel(self.gptj)

+   @property
+   def eot_token(self):
+       return self.tokenizer.eos_token
+
    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
...
@@ -62,7 +62,7 @@ class T5LM(BaseLM):
    @property
    def max_gen_toks(self):
-       return self.tokenizer.model_max_length
+       return 256

    @property
    def batch_size(self):
@@ -100,6 +100,14 @@ class T5LM(BaseLM):
            inputs, targets = zip(*chunk)

+           # Fill in empty encoder inputs with eos_token
+           inputs = (
+               f"{self.eot_token}"
+               if len(input_) == 0
+               else input_
+               for input_ in inputs
+           )
+
            inputs_tok = self.tokenizer(
                list(inputs),
                max_length=self.max_length,
@@ -178,11 +186,21 @@ class T5LM(BaseLM):
                EOSCriteria(self.tokenizer.eos_token)
            ])

-   def _model_generate(self, context, max_length, stopping_criteria_ids):
+   def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-       return self.t5.generate(
-           context,
-           max_length=max_length,
-           stopping_criteria=stopping_criteria,
-           do_sample=False,
-       )
+       if num_fewshot == 0:
+           generations = self.t5.generate(
+               context,
+               max_length=max_length,
+               eos_token_id=self.eot_token_id,
+               do_sample=False,
+           )
+       else:
+           generations = self.t5.generate(
+               context,
+               max_length=max_length,
+               stopping_criteria=stopping_criteria,
+               do_sample=False,
+           )
+       return generations[0]
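A minimal sketch of the empty-input fill above, assuming the model's eot_token is "</s>" (T5's EOS string):

    eot = "</s>"
    raw_inputs = ["translate: hello", "", "summarize: hi"]
    inputs = [eot if len(s) == 0 else s for s in raw_inputs]
    # ["translate: hello", "</s>", "summarize: hi"]; the tokenizer never sees an empty string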
@@ -56,6 +56,11 @@ from . import hans
from . import gem_webnlg
from . import lama
# from . import e2e_nlg_cleaned
+from . import gem_xsum
+from . import gem_mlsum
+from . import wino_bias
+from . import e2e_nlg_cleaned
+from . import gem_asset_turk

########################################
# Translation tasks
@@ -109,11 +114,12 @@ TASK_REGISTRY = {
    "wsc": superglue.SGWinogradSchemaChallenge,
    # Order by benchmark/genre?
    "coqa": coqa.CoQA,
-   "GEM/web_nlg": gem_webnlg.WebNLG,
    "drop": drop.DROP,
    "lambada": lambada.LAMBADA,
    "lambada_cloze": lambada_cloze.LAMBADA_cloze,
+   **gem_webnlg.construct_tasks(),
    # multilingual lambada
+   **gem_asset_turk.construct_tasks(),
    **lambada_multilingual.construct_tasks(),
    "wikitext": wikitext.WikiText,
    # "cbt-cn": cbt.CBTCN,  # disabled pending context length fix
@@ -124,7 +130,7 @@ TASK_REGISTRY = {
    # Science related
    "pubmedqa": pubmedqa.Pubmed_QA,
    "sciq": sciq.SciQ,
-   # "e2e_nlg_cleaned": e2e_nlg_cleaned.E2E_NLG_Cleaned,
+   "e2e_nlg_cleaned": e2e_nlg_cleaned.E2E_NLG_Cleaned,
    "qasper": qasper.QASPER,
    "qa4mre_2011": qa4mre.QA4MRE_2011,
    "qa4mre_2012": qa4mre.QA4MRE_2012,
@@ -293,10 +299,32 @@ TASK_REGISTRY = {
    "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
    "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
    "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
+   # GEM/mlsum
+   "mlsum_es": gem_mlsum.GEMMLSUMEs,
+   "mlsum_de": gem_mlsum.GEMMLSUMDe,
+   "mlsum_es_covid_challenge_set": gem_mlsum.GEMMLSUMEsChallgeTestCovid,
+   "mlsum_de_covid_challenge_set": gem_mlsum.GEMMLSUMDeChallgeTestCovid,
    # Requires manual download of data.
    # "storycloze_2016": storycloze.StoryCloze2016,
    # "storycloze_2018": storycloze.StoryCloze2018,
    # "sat": sat.SATAnalogies,
+   # GEM/xsum
+   "gem_xsum": gem_xsum.GEMXSUM,
+   "gem_xsum_challenge_sample": gem_xsum.GEMXSUMChallgeSample,
+   "gem_xsum_challenge_test_backtranslation": gem_xsum.GEMXSUMChallgeTestBacktranslation,
+   "gem_xsum_challenge_test_bfp_02": gem_xsum.GEMXSUMChallgeTestBFP02,
+   "gem_xsum_challenge_test_bfp_05": gem_xsum.GEMXSUMChallgeTestBFP05,
+   "gem_xsum_challenge_test_nopunc": gem_xsum.GEMXSUMChallgeTestNopunc,
+   "gem_xsum_challenge_test_covid": gem_xsum.GEMXSUMChallgeTestCovid,
+   # WinoBias
+   "wino_bias_type1_pro": wino_bias.WinoBiasType1Pro,
+   "wino_bias_type1_anti": wino_bias.WinoBiasType1Anti,
+   "wino_bias_type2_pro": wino_bias.WinoBiasType2Pro,
+   "wino_bias_type2_anti": wino_bias.WinoBiasType2Anti,
}
...
@@ -10,7 +10,7 @@ grammars.
Homepage: https://github.com/alexwarstadt/blimp
"""
-from lm_eval.base import rf, Task
+from lm_eval.base import rf, PromptSourceTask
from lm_eval.metrics import mean
@@ -31,7 +31,7 @@ _CITATION = """
"""


-class BlimpTask(Task):
+class BlimpTask(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "blimp"
@@ -50,58 +50,6 @@ class BlimpTask(Task):
        # trained on this data.
        return self.dataset["train"]

-   def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
-       assert num_fewshot == 0
-       assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
-       assert not provide_description, (
-           "The `provide_description` arg will be removed in future versions. To prepend "
-           "a custom description to the context, supply the corresponding string via the "
-           "`description` arg."
-       )
-       if provide_description is not None:
-           # nudge people to not specify it at all
-           print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
-
-       return ""
-
-   def doc_to_text(self, doc):
-       # this method is invoked by tests only
-       return ""
-
-   def doc_to_target(self, doc):
-       # this method is invoked by tests only
-       return ""
-
-   def construct_requests(self, doc, ctx):
-       assert not ctx
-
-       # Calculate the loglikelihood for the good and the bad sentence.
-       # Note that loglikelihood translates the "" prefix to the "<|endoftext|>" token
-       return [
-           rf.loglikelihood("", doc["sentence_good"]),
-           rf.loglikelihood("", doc["sentence_bad"]),
-       ]
-
-   def process_results(self, doc, results):
-       likelihood1, likelihood2 = results
-
-       # the model got this case right iff the good sentence scored higher than the bad sentence
-       acc = 1.0 if likelihood1 > likelihood2 else 0.0
-
-       return {
-           "acc": acc,
-       }
-
-   def higher_is_better(self):
-       return {
-           "acc": True,
-       }
-
-   def aggregation(self):
-       return {
-           "acc": mean,
-       }
-

class BlimpAdjunctIsland(BlimpTask):
    DATASET_NAME = "adjunct_island"
...
"""
Semantic Noise Matters for Neural Natural Language Generation
http://arxiv.org/abs/1911.03905
A cleaned version of the dataset from the E2E NLG Challenge.
The dataset contains MR with restaurant attributes and corresponding descriptions.
Homepage: https://github.com/tuetschek/e2e-cleaning
"""
from lm_eval.base import PromptSourceTask, rf
from lm_eval import metrics
_CITATION = """
@inproceedings{dusek-etal-2019-semantic,
title = "Semantic Noise Matters for Neural Natural Language Generation",
author = "Du{\v{s}}ek, Ond{\v{r}}ej and
Howcroft, David M. and
Rieser, Verena",
booktitle = "Proceedings of the 12th International Conference on Natural Language Generation",
year = "2019",
address = "Tokyo, Japan",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/W19-8652",
doi = "10.18653/v1/W19-8652",
pages = "421--426",
}
"""
# Work in progress
class E2E_NLG_Cleaned(PromptSourceTask):
VERSION = 0
DATASET_PATH = "e2e_nlg_cleaned"
DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster
# few-shot processing. If the data is too large to fit in memory,
# return the training data as a generator instead of a list.
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
def max_generation_length(self):
return 64
# def stopping_criteria(self):
# return '\n\n'
def invalid_doc_for_prompt(self, doc) -> bool:
"""The QA prompts are not applicable to all the examples, we want to filter these out."""
return self.prompt.name.endswith("_qa") or self.prompt.name == "family_friendly_yes_no"
def doc_to_text(self, doc) -> str:
# if the response is not defined in PS, the text will be a single-element list containing an empty string
text = self.prompt.apply(doc)[0]
return text
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
_requests = []
# NOTE: In the future, target will be a list of strings.
request_args = {
"stopping_criteria": self.stopping_criteria(),
"max_generation_length": self.max_generation_length(),
}
# Skip examples for which the templates are not applicable
if ctx != "":
cont_request = rf.greedy_until(ctx, request_args)
_requests.append(cont_request)
return _requests
"""
ASSET: ASSET (Alva-Manchego et al., 2020) is multi-reference dataset
for the evaluation of sentence simplification in English. The dataset
uses the same 2,359 sentences from TurkCorpus (Xu et al., 2016)
and each sentence is associated with 10 crowdsourced simplifications.
Unlike previous simplification datasets, which contain a single
transformation (e.g., lexical paraphrasing in TurkCorpus or sentence
splitting in HSplit), the simplifications in ASSET encompass a variety
of rewriting transformations.
https://aclanthology.org/2020.acl-main.424.pdf
TurkCorpus: TURKCorpus is a multi-reference dataset for the evaluation of
sentence simplification in English. The dataset consists of 2,359 sentences
from the Parallel Wikipedia Simplification (PWKP) corpus. Each sentence is
associated with 8 crowdsourced simplifications that focus on only lexical
paraphrasing (no sentence splitting or deletion).
https://cocoxu.github.io/publications/tacl2016-smt-simplification.pdf
"""
from lm_eval.base import PromptSourceTask
_CITATION = """
@article{DBLP:journals/corr/abs-2005-00481,
author = {Fernando Alva{-}Manchego and
Louis Martin and
Antoine Bordes and
Carolina Scarton and
Beno{\^{\i}}t Sagot and
Lucia Specia},
title = {{ASSET:} {A} Dataset for Tuning and Evaluation of Sentence Simplification
Models with Multiple Rewriting Transformations},
journal = {CoRR},
volume = {abs/2005.00481},
year = {2020},
url = {https://arxiv.org/abs/2005.00481},
eprinttype = {arXiv},
eprint = {2005.00481},
timestamp = {Thu, 14 Oct 2021 16:38:25 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-2005-00481.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}"""
""""@article{Xu-EtAl:2016:TACL,
author = {Wei Xu and Courtney Napoles and Ellie Pavlick and Quanze Chen and Chris Callison-Burch},
title = {Optimizing Statistical Machine Translation for Text Simplification},
journal = {Transactions of the Association for Computational Linguistics},
volume = {4},
year = {2016},
url = {https://cocoxu.github.io/publications/tacl2016-smt-simplification.pdf},
pages = {401--415}
}"""
class AssetTurk(PromptSourceTask):
VERSION = 0
DATASET_PATH = "GEM/wiki_auto_asset_turk"
DATASET_NAME = None
SPLIT = None
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
return self.dataset[str(self.SPLIT)]
def stopping_criteria(self):
return None
def max_generation_length(self):
return 200
# def higher_is_better(self):
# return {"bleu": True, "rouge": True}
class AssetTest(AssetTurk):
SPLIT = "test_asset"
class TurkTest(AssetTurk):
SPLIT = "test_turk"
class AssetTest1(AssetTurk):
SPLIT = "challenge_test_asset_backtranslation"
class AssetTest2(AssetTurk):
SPLIT = "challenge_test_asset_bfp02"
class AssetTest3(AssetTurk):
SPLIT = "challenge_test_asset_bfp05"
class AssetTest4(AssetTurk):
SPLIT = "challenge_test_asset_nopunc"
class TurkTest1(AssetTurk):
SPLIT = "challenge_test_turk_backtranslation"
class TurkTest2(AssetTurk):
SPLIT = "challenge_test_turk_bfp02"
class TurkTest3(AssetTurk):
SPLIT = "challenge_test_turk_bfp05"
class TurkTest4(AssetTurk):
SPLIT = "challenge_test_turk_nopunc"
ASSET_TURK_CLASSES = [
AssetTest,
TurkTest,
TurkTest1,
TurkTest2,
TurkTest3,
TurkTest4,
AssetTest1,
AssetTest2,
AssetTest3,
AssetTest4,
]
def construct_tasks():
tasks = {}
for asset_turk_class in ASSET_TURK_CLASSES:
tasks[f"GEM/wiki_auto_asset_turk_{asset_turk_class.SPLIT}"] = asset_turk_class
return tasks
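An illustrative sketch of what this registers (the names are derived from the SPLIT attributes above):

    tasks = construct_tasks()
    sorted(tasks)[:2]
    # ['GEM/wiki_auto_asset_turk_challenge_test_asset_backtranslation',
    #  'GEM/wiki_auto_asset_turk_challenge_test_asset_bfp02']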
"""
MLSUM: The Multilingual Summarization Corpus
https://aclanthology.org/2020.emnlp-main.647/
This is the MLSUM subset of the GEM benchmark. MLSUM is the first large-scale MultiLingual SUMmarization dataset.
Obtained from online newspapers, it contains 1.5M+ article/summary pairs in five different languages -- namely, French, German, Spanish, Russian, Turkish.
Together with English newspapers from the popular CNN/Daily mail dataset, the collected data form a large scale multilingual dataset which can enable new research directions for the text summarization community.
We report cross-lingual comparative analyses based on state-of-the-art systems.
These highlight existing biases which motivate the use of a multi-lingual dataset.
Homepage: https://gitlab.lip6.fr/scialom/mlsum_data/-/raw/master/MLSUM/
"""
from numpy import True_
from lm_eval.base import PromptSourceTask
_CITATION = """
@article{scialom2020mlsum,
title={MLSUM: The Multilingual Summarization Corpus},
author={Scialom, Thomas and Dray, Paul-Alexis and Lamprier, Sylvain and Piwowarski, Benjamin and Staiano, Jacopo},
journal={arXiv preprint arXiv:2004.14900},
year={2020}
}
"""
class GEMMLSUMEsBase(PromptSourceTask):
VERSION = 0
DATASET_PATH = "GEM/mlsum"
DATASET_NAME = "es"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
def stopping_criteria(self):
return "."
class GEMMLSUMEs(GEMMLSUMEsBase):
'''this is for train/validation/test'''
SPLIT = ''
class GEMMLSUMEsChallgeTestCovid(GEMMLSUMEsBase):
'''this is for challenge_test_covid'''
SPLIT = 'challenge_test_covid'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMMLSUMDeBase(PromptSourceTask):
VERSION = 0
DATASET_PATH = "GEM/mlsum"
DATASET_NAME = "de"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
def stopping_criteria(self):
return "."
class GEMMLSUMDe(GEMMLSUMDeBase):
'''this is for train/validation/test'''
SPLIT = ''
class GEMMLSUMDeChallgeTestCovid(GEMMLSUMDeBase):
'''this is for challenge_test_covid'''
SPLIT = 'challenge_test_covid'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
"""
The 2020 Bilingual, Bi-Directional WebNLG+ Shared Task:
Overview and Evaluation Results (WebNLG+ 2020)
https://aclanthology.org/2020.webnlg-1.7/
WebNLG+ offers two challenges: (i) mapping sets of RDF triples
to English or Russian text (generation) and (ii) converting
English or Russian text to sets of RDF triples (semantic parsing).
Compared to the eponymous WebNLG challenge, WebNLG+ provides an
extended dataset that enable the training, evaluation, and
comparison of microplanners and semantic parsers. In this paper,
we present the results of the generation and semantic parsing
task for both English and Russian and provide a brief
description of the participating systems.
"""
from lm_eval.base import PromptSourceTask from lm_eval.base import PromptSourceTask
_CITATION = """
@inproceedings{castro-ferreira-etal-2020-2020,
title = "The 2020 Bilingual, Bi-Directional {W}eb{NLG}+ Shared Task: Overview and Evaluation Results ({W}eb{NLG}+ 2020)",
author = "Castro Ferreira, Thiago and
Gardent, Claire and
Ilinykh, Nikolai and
van der Lee, Chris and
Mille, Simon and
Moussallem, Diego and
Shimorina, Anastasia",
booktitle = "Proceedings of the 3rd International Workshop on Natural Language Generation from the Semantic Web (WebNLG+)",
month = "12",
year = "2020",
address = "Dublin, Ireland (Virtual)",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2020.webnlg-1.7",
pages = "55--76",
abstract = "WebNLG+ offers two challenges: (i) mapping sets of RDF triples to English or Russian text (generation) and (ii) converting English or Russian text to sets of RDF triples (semantic parsing). Compared to the eponymous WebNLG challenge, WebNLG+ provides an extended dataset that enable the training, evaluation, and comparison of microplanners and semantic parsers. In this paper, we present the results of the generation and semantic parsing task for both English and Russian and provide a brief description of the participating systems.",
}
"""
class WebNLG(PromptSourceTask): class WebNLG(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "GEM/web_nlg" DATASET_PATH = "GEM/web_nlg"
DATASET_NAME = "en" DATASET_NAME = "en"
SPLIT = None
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -27,11 +65,71 @@ class WebNLG(PromptSourceTask): ...@@ -27,11 +65,71 @@ class WebNLG(PromptSourceTask):
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
if self.SPLIT is not None:
return self.dataset[str(self.SPLIT)]
else:
return self.dataset["test"] return self.dataset["test"]
def stopping_criteria(self): def stopping_criteria(self):
return '*' return None
def max_generation_length(self): def max_generation_length(self):
return 250 return 250
# def higher_is_better(self):
# return {"bleu": True, "rouge": True}
class WebNLGRu(WebNLG):
DATASET_NAME = "ru"
## En Challenge Sets
class WebNLGEn1(WebNLG):
SPLIT = "challenge_validation_sample"
class WebNLGEn2(WebNLG):
SPLIT = "challenge_test_scramble"
class WebNLGEn3(WebNLG):
SPLIT = "challenge_test_numbers"
## Ru Challenge sets
class WebNLGRu1(WebNLG):
DATASET_NAME = "ru"
SPLIT = "challenge_validation_sample"
class WebNLGRu2(WebNLG):
DATASET_NAME = "ru"
SPLIT = "challenge_test_scramble"
WEBNLG_CLASSES = [
WebNLG,
WebNLGRu,
WebNLGEn1,
WebNLGEn2,
WebNLGEn3,
WebNLGRu1,
WebNLGRu2,
]
def construct_tasks():
tasks = {}
for webnlg_class in WEBNLG_CLASSES:
if webnlg_class.SPLIT is None:
tasks[f"GEM/web_nlg_{webnlg_class.DATASET_NAME}"] = webnlg_class
else:
tasks[
f"GEM/web_nlg_{webnlg_class.DATASET_NAME}_{webnlg_class.SPLIT}"
] = webnlg_class
return tasks
"""
Don’t Give Me the Details, Just the Summary! Topic-Aware Convolutional Neural Networks for Extreme Summarization
https://arxiv.org/pdf/1808.08745.pdf
The dataset is for the task of abstractive summarization in its extreme form, its about summarizing a document in a single sentence. It introduces extreme summarization, a new single-document summarization task which does not favor extractive strategies and calls for an abstractive modeling approach. The idea is to create a short, one-sentence news summary answering the question "What is the article about?".
This particularly uses the dataset that is part of the GEM benchmark
Homepage: https://github.com/EdinburghNLP/XSum
The GEM Benchmark: Natural Language Generation, its Evaluation and Metrics
https://arxiv.org/pdf/2102.01672v3.pdf
Write a Short Description of the task.
Homepage: https://gem-benchmark.com/data_cards/XSum
"""
from lm_eval.base import PromptSourceTask
from lm_eval.base import Task, rf
_CITATION = """
@InProceedings{xsum-emnlp,
author = "Shashi Narayan and Shay B. Cohen and Mirella Lapata",
title = "Don't Give Me the Details, Just the Summary! {T}opic-Aware Convolutional Neural Networks for Extreme Summarization",
booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing ",
year = "2018",
address = "Brussels, Belgium",
}
"""
class GEMXSUMBase(PromptSourceTask):
VERSION = 0
DATASET_PATH = "GEM/xsum"
DATASET_NAME = None
SPLIT = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def stopping_criteria(self):
return '.'
def training_docs(self):
if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster
# few-shot processing. If the data is too large to fit in memory,
# return the training data as a generator instead of a list.
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
class GEMXSUM(GEMXSUMBase):
'''this is for train/validation/test'''
SPLIT = ''
class GEMXSUMChallgeSample(GEMXSUMBase):
'''this is for challenge_train_sample/challenge_validation_sample'''
SPLIT = 'challenge_sample'
def has_test_docs(self):
return False
def training_docs(self):
if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster
# few-shot processing. If the data is too large to fit in memory,
# return the training data as a generator instead of a list.
if self._training_docs is None:
self._training_docs = list(self.dataset["challenge_train_sample"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["challenge_validation_sample"]
class GEMXSUMChallgeTestBacktranslation(GEMXSUMBase):
'''this is for challenge_test_backtranslation'''
SPLIT = 'challenge_test_backtranslation'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMXSUMChallgeTestBFP02(GEMXSUMBase):
'''this is for challenge_test_bfp_02'''
SPLIT = 'challenge_test_bfp_02'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMXSUMChallgeTestBFP05(GEMXSUMBase):
'''this is for challenge_test_bfp_05'''
SPLIT = 'challenge_test_bfp_05'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMXSUMChallgeTestNopunc(GEMXSUMBase):
'''this is for challenge_test_nopunc'''
SPLIT = 'challenge_test_nopunc'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMXSUMChallgeTestCovid(GEMXSUMBase):
'''this is for challenge_test_covid'''
SPLIT = 'challenge_test_covid'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
\ No newline at end of file
@@ -305,3 +305,39 @@ class SGWinogradSchemaChallenge(PromptSourceTask):
    def aggregation(self):
        return {"acc": mean}

+
+class WinogenderSchemaDiagnostics(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "super_glue"
+    DATASET_NAME = "axg"
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def test_docs(self):
+        return self.dataset["test"]
+
+
+class BroadcoverageDiagnostics(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "super_glue"
+    DATASET_NAME = "axb"
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def test_docs(self):
+        return self.dataset["test"]
"""
Gender Bias in Coreference Resolution: Evaluation and Debiasing Methods
https://arxiv.org/abs/1804.06876
Winograd-schema evaluation of gendered coreference resolution.
The dataset contains pro-stereotypical and anti-stereotypical parts. The difference in accuracy for those two subsets
quatnifies bias.
Homepage: https://uclanlp.github.io/corefBias/overview
"""
from lm_eval.base import PromptSourceTask, mean
import transformers.data.metrics.squad_metrics as squad_metrics
_CITATION = """
@inproceedings{zhao-etal-2018-gender,
title = "Gender Bias in Coreference Resolution: Evaluation and Debiasing Methods",
author = "Zhao, Jieyu and
Wang, Tianlu and
Yatskar, Mark and
Ordonez, Vicente and
Chang, Kai-Wei",
booktitle = "Proceedings of the 2018 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 2 (Short Papers)",
month = jun,
year = "2018",
address = "New Orleans, Louisiana",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/N18-2003",
doi = "10.18653/v1/N18-2003",
pages = "15--20",
abstract = "In this paper, we introduce a new benchmark for co-reference resolution focused on gender bias, WinoBias. Our corpus contains Winograd-schema style sentences with entities corresponding to people referred by their occupation (e.g. the nurse, the doctor, the carpenter). We demonstrate that a rule-based, a feature-rich, and a neural coreference system all link gendered pronouns to pro-stereotypical entities with higher accuracy than anti-stereotypical entities, by an average difference of 21.1 in F1 score. Finally, we demonstrate a data-augmentation approach that, in combination with existing word-embedding debiasing techniques, removes the bias demonstrated by these systems in WinoBias without significantly affecting their performance on existing datasets.",
}
"""
class WinoBias(PromptSourceTask):
VERSION = 0
DATASET_PATH = "wino_bias"
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
pass
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
return self.dataset["test"]
def stopping_criteria(self):
return "\n"
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
target = self.doc_to_target(doc).strip()
pred = " ".join(results[0].strip().split(" ")[:len(target.split(" "))])
# The original paper uses F1. In the case of exactly one predicted and one gold mention,
# F1 and exact match are equivalent.
em = squad_metrics.compute_exact(target, pred)
out = {"em": em}
if self.save_examples:
example = {"target": target, "pred": pred}
return out, example
return out
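A minimal sketch of the metric computed here, using squad_metrics.compute_exact, which normalizes articles, casing, and punctuation before comparing:

    import transformers.data.metrics.squad_metrics as squad_metrics

    squad_metrics.compute_exact("the physician", "The physician")  # -> 1
    squad_metrics.compute_exact("the physician", "the secretary")  # -> 0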
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        return {"em": mean}

    def higher_is_better(self):
        return {"em": True}


class WinoBiasType1Pro(WinoBias):
    DATASET_NAME = "type1_pro"


class WinoBiasType1Anti(WinoBias):
    DATASET_NAME = "type1_anti"


class WinoBiasType2Pro(WinoBias):
    DATASET_NAME = "type2_pro"


class WinoBiasType2Anti(WinoBias):
    DATASET_NAME = "type2_anti"
@@ -11,14 +11,14 @@ EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"

def parse_args():
    parser = argparse.ArgumentParser()
-   parser.add_argument('--output_base_path', required=True)
-   parser.add_argument('--tasks', default="all_tasks")
-   parser.add_argument('--provide_description', action="store_true")
-   parser.add_argument('--sets', type=str, default="val")  # example: val,test
-   parser.add_argument('--num_fewshot', type=int, default=1)
-   parser.add_argument('--seed', type=int, default=42)
-   parser.add_argument('--num_examples', type=int, default=1)
-   parser.add_argument('--description_dict_path', default=None)
+   parser.add_argument("--output_base_path", required=True)
+   parser.add_argument("--tasks", default="all_tasks")
+   parser.add_argument("--provide_description", action="store_true")
+   parser.add_argument("--sets", type=str, default="val")  # example: val,test
+   parser.add_argument("--num_fewshot", type=int, default=1)
+   parser.add_argument("--seed", type=int, default=42)
+   parser.add_argument("--num_examples", type=int, default=1)
+   parser.add_argument("--description_dict_path", default=None)
    return parser.parse_args()
@@ -34,7 +34,7 @@ def main():
    description_dict = {}
    if args.description_dict_path:
-       with open(args.description_dict_path, 'r') as f:
+       with open(args.description_dict_path, "r") as f:
            description_dict = json.load(f)

    os.makedirs(args.output_base_path, exist_ok=True)
@@ -45,26 +45,34 @@ def main():
        iters = []

        for set in args.sets.split(","):
-           if set == 'train' and task.has_training_docs():
+           if set == "train" and task.has_training_docs():
                docs = task.training_docs()
-           if set == 'val' and task.has_validation_docs():
+           if set == "val" and task.has_validation_docs():
                docs = task.validation_docs()
-           if set == 'test' and task.has_test_docs():
+           if set == "test" and task.has_test_docs():
                docs = task.test_docs()
            iters.append(docs)

        docs = join_iters(iters)

-       description = description_dict[task_name] if description_dict and task_name in description_dict else ""
-       task_name = task_name.replace('/','_')
+       description = (
+           description_dict[task_name]
+           if description_dict and task_name in description_dict
+           else ""
+       )
+       task_name = task_name.replace("/", "_")
        with open(os.path.join(args.output_base_path, task_name), "w") as f:
-           for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs):
+           for i, doc in (
+               zip(range(args.num_examples), docs)
+               if args.num_examples > 0
+               else enumerate(docs)
+           ):
                f.write(EXAMPLE_DIVIDER.format(i=i))
-               ctx = task.fewshot_context(
+               ctx, _ = task.fewshot_context(
                    doc=doc,
                    num_fewshot=args.num_fewshot,
                    rnd=rnd,
-                   description=description
+                   description=description,
                )
                f.write(ctx + "\n")
...
@@ -29,7 +29,7 @@ setuptools.setup(
        "click>=7.1",
        "scikit-learn>=0.24.1",
        "torch>=1.7",
-       "transformers>=4.1",
+       "transformers>=4.16",
        "sqlitedict==1.6.0",
        "pytablewriter==0.58.0",
        "sacrebleu==1.5.0",
...