Commit b7f070d8 authored by Zdenek Kasner

Merge remote-tracking branch 'origin/master' into kasnerz/generation_tasks

parents 9aba053a 9cd70235
......@@ -173,10 +173,6 @@ def evaluate(
# get lists of each type of request
for task_prompt_name, task in task_dict_items:
# if task.is_generation_task():
# print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
# continue
versions[task_prompt_name] = task.VERSION
# default to test docs, fall back to validation docs if the test split is unavailable
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
......@@ -188,7 +184,7 @@ def evaluate(
raise RuntimeError("Task has neither test_docs nor validation_docs")
# deterministically shuffle the docs and keep only the first `limit`, because docs sometimes come in a meaningful order
task_docs = list(task_doc_func())
task_docs = list(enumerate(list(task_doc_func())))
rnd = random.Random()
rnd.seed(42)
rnd.shuffle(task_docs)
......@@ -199,14 +195,17 @@ def evaluate(
else ""
)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
for doc_id, (original_doc_id, doc) in enumerate(
itertools.islice(task_docs, 0, limit)
):
if task.invalid_doc_for_prompt(doc):
continue
docs[(task_prompt_name, doc_id)] = doc
ctx = task.fewshot_context(
ctx, fewshotex_logging_info = task.fewshot_context(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
fewshotex_logging_info["doc_id"] = original_doc_id
reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
......@@ -215,7 +214,7 @@ def evaluate(
# i: index in requests for a single task instance
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append(
(i, task_prompt_name, doc, doc_id)
(i, task_prompt_name, doc, doc_id, fewshotex_logging_info)
)
# all responses for each (task, doc)
......@@ -234,33 +233,57 @@ def evaluate(
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]
for resp, (i, task_prompt_name, doc, doc_id) in zip(
for resp, (i, task_prompt_name, doc, doc_id, fewshotex_logging_info) in zip(
resps, requests_origin[reqtype]
):
process_res_queue[(task_prompt_name, doc_id)].append((i, resp))
process_res_queue[(task_prompt_name, doc_id)].append(
(i, resp, fewshotex_logging_info)
)
vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task
for (task_prompt_name, doc_id), requests in process_res_queue.items():
requests.sort(key=lambda x: x[0])
requests = [x[1] for x in requests]
examples = []
for (task_prompt_name, doc_id), per_doc_requests in process_res_queue.items():
per_doc_requests.sort(key=lambda x: x[0])
per_doc_results = [x[1] for x in per_doc_requests]
fewshot_logging_info = [x[2] for x in per_doc_requests][0]
task = task_dict[task_prompt_name]
doc = docs[(task_prompt_name, doc_id)]
metrics = task.process_results(doc, requests)
output = task.process_results(doc, per_doc_results)
if task.save_examples:
metrics, example = output
example.update(fewshot_logging_info)
example.update(task.get_logging_info())
examples.append(example)
else:
metrics = output
example = fewshot_logging_info
example.update(task.get_logging_info())
examples.append(example)
for metric, value in metrics.items():
vals[(task_prompt_name, metric)].append(value)
# aggregate results
metric_results = []
for (task_prompt_name, metric), items in vals.items():
task_name, prompt_name = task_prompt_name.split("+")
results[task_prompt_name]["task_name"] = task_name
results[task_prompt_name]["prompt_name"] = prompt_name
task = task_dict[task_prompt_name]
results[task_prompt_name][metric] = task.aggregation()[metric](items)
_metric_results = {
"task_name": task_name,
"prompt_name": prompt_name,
metric: task.aggregation()[metric](items),
**task.get_logging_info(),
}
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
# so we run them for fewer iterations. still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric(
......@@ -271,8 +294,18 @@ def evaluate(
)
if stderr is not None:
results[task_prompt_name][metric + "_stderr"] = stderr(items)
return {"results": dict(results), "versions": dict(versions)}
_metric_results[metric + "_stderr"] = stderr(items)
metric_results.append(_metric_results)
return {
# List of results that tracks the averages per model and prompt.
"results": metric_results,
"versions": dict(versions),
# List of all prompt x doc examples with additional information in it.
"examples": examples,
# Original results used for generating the table when running this file.
"table_results": dict(results),
}
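A minimal sketch of how a caller might consume the new return structure (the key names come from the dict above; the call site, file name, and serialization are illustrative assumptions, not part of this change):

    import json

    output = evaluate(...)  # hypothetical call site; arguments omitted

    # per (task, prompt, metric) aggregates, one flat dict per metric
    for row in output["results"]:
        print(row["task_name"], row["prompt_name"])

    # per-document examples with few-shot logging info, e.g. dumped as JSON lines
    with open("examples.jsonl", "w") as f:
        for example in output["examples"]:
            f.write(json.dumps(example) + "\n")

    # "table_results" keeps the old nested layout, which make_table() still expects
    print(make_table(output))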
def make_table(result_dict):
......@@ -293,7 +326,7 @@ def make_table(result_dict):
]
values = []
for k, dic in result_dict["results"].items():
for k, dic in result_dict["table_results"].items():
version = result_dict["versions"][k]
for m, v in dic.items():
if m.endswith("_stderr"):
......
......@@ -72,6 +72,10 @@ class HFLM(BaseLM):
# if gpus > 1:
# self.gpt2 = nn.DataParallel(self.gpt2)
@property
def eot_token(self):
return self.tokenizer.eos_token
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
......
......@@ -39,6 +39,10 @@ class GPTJLM(BaseLM):
# if gpus > 1:
# self.gptj = nn.DataParallel(self.gptj)
@property
def eot_token(self):
return self.tokenizer.eos_token
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
......
......@@ -2,37 +2,36 @@ import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval.base import BaseLM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
import math
class T0LM(LM):
MAX_GEN_TOKS = 256
MAX_INP_LENGTH = 512
VOCAB_SIZE = 32100
EOT_TOKEN_ID = 1
class T0LM(BaseLM):
# MAX_GEN_TOKS = 256
# MAX_INP_LENGTH = 512
# VOCAB_SIZE = 32100
# EOT_TOKEN_ID = 1
def __init__(self, device='cuda', parallelize=False, pretrained='t0', batch_size=1):
super().__init__()
if device:
self.device = torch.device(device)
self._device = torch.device(device)
else:
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(pretrained)
self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.t0 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
self.t0.eval()
if parallelize == "True":
print(parallelize)
self.t0.parallelize()
self.device = torch.device('cuda:0')
self._device = torch.device('cuda:0')
else:
self.t0.to(self.device)
self.t0.to(self._device)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
self.max_length = self.MAX_INP_LENGTH
# self.max_length = self.MAX_INP_LENGTH
self._batch_size = int(batch_size)
......@@ -42,6 +41,53 @@ class T0LM(LM):
args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(**args, **args2)
@property
def eot_token(self):
return self.tokenizer.eos_token
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
return self.tokenizer.model_max_length
@property
def max_gen_toks(self):
return self.tokenizer.model_max_length
@property
def batch_size(self):
# TODO: fix multi-gpu
return self._batch_size # * gpus
@property
def device(self):
# TODO: fix multi-gpu
return self._device
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def _model_call(self, inputs_tok, targets_tok):
"""
inputs_tok: a tokenizer BatchEncoding for the encoder inputs; its input_ids have shape [batch, sequence], and the sequence length may vary from call to call
targets_tok: a tokenizer BatchEncoding whose input_ids are passed to the model as decoder labels
returns: the seq2seq model output; outputs.logits has shape [batch, target_sequence, vocab]
"""
with torch.no_grad():
return self.t0(
**inputs_tok,
labels=targets_tok["input_ids"]
)
def loglikelihood(self, requests):
res = []
for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)):
......@@ -62,7 +108,7 @@ class T0LM(LM):
targets_tok = self.tokenizer(
list(targets),
max_length=self.MAX_GEN_TOKS,
max_length=self.max_gen_toks,
padding=True,
# truncation=True,
add_special_tokens=False,
......@@ -72,11 +118,7 @@ class T0LM(LM):
for key in targets_tok:
targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]
with torch.no_grad():
outputs = self.t0(
**inputs_tok,
labels=targets_tok["input_ids"]
)
outputs = self._model_call(inputs_tok, targets_tok)
log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
......@@ -103,9 +145,6 @@ class T0LM(LM):
res.append(answer)
return res
def loglikelihood_rolling(self, requests):
raise NotImplementedError
def _get_stopping_criteria(self, stopping_criteria_ids):
class MultitokenEOSCriteria(transformers.StoppingCriteria):
......@@ -133,29 +172,11 @@ class T0LM(LM):
EOSCriteria(self.tokenizer.eos_token)
])
def greedy_until(self, requests):
res = []
for context, until in tqdm(requests):
if isinstance(until, str): until = [until]
context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
stopping_criteria_ids = self.tokenizer.encode(until[0])
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
cont = self.t0.generate(
context_enc,
max_length=self.MAX_GEN_TOKS,
stopping_criteria=stopping_criteria,
do_sample=False
)
s = self.tokenizer.decode(cont[0].tolist())
self.cache_hook.add_partial("greedy_until", (context, until), s)
res.append(s)
return res
\ No newline at end of file
def _model_generate(self, context, max_length, stopping_criteria_ids):
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
return self.t0.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
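A hedged usage sketch for the refactored T0LM (the checkpoint name is only an example; the (input, target) string-pair format follows the unpacking at the top of loglikelihood above):

    lm = T0LM(pretrained="bigscience/T0_3B", device="cuda", batch_size=2)

    # each request is an (encoder input, target continuation) pair of strings
    answers = lm.loglikelihood([
        ("Review: great movie! Is this review positive or negative?", "positive"),
    ])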
......@@ -2,39 +2,44 @@ import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval.base import BaseLM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
import math
class T5LM(LM):
MAX_GEN_TOKS = 256
MAX_INP_LENGTH = 512
VOCAB_SIZE = 32128
EOT_TOKEN_ID = 1
def __init__(self, device='cuda', parallelize=False, pretrained='t5', batch_size=1):
class T5LM(BaseLM):
# MAX_GEN_TOKS = 256
# MAX_INP_LENGTH = 512
# VOCAB_SIZE = 32128
# EOT_TOKEN_ID = 1
def __init__(
self,
device='cuda',
parallelize=False,
pretrained='t5',
batch_size=1
):
super().__init__()
if device:
self.device = torch.device(device)
self._device = torch.device(device)
else:
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(pretrained)
self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.t5 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
self.t5.eval()
if parallelize == "True":
print(parallelize)
self.t5.parallelize()
self.device = torch.device('cuda:0')
self._device = torch.device('cuda:0')
else:
self.t5.to(self.device)
self.t5.to(self._device)
self.tokenizer = transformers.T5TokenizerFast.from_pretrained(pretrained)
self.max_length = self.MAX_INP_LENGTH
self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
# self.max_length = self.MAX_INP_LENGTH
self.batch_size = int(batch_size)
self._batch_size = int(batch_size)
@classmethod
def create_from_arg_string(cls, arg_string, additional_config={}):
......@@ -42,12 +47,67 @@ class T5LM(LM):
args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(**args, **args2)
@property
def eot_token(self):
return self.tokenizer.eos_token
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
return self.tokenizer.model_max_length
@property
def max_gen_toks(self):
return self.tokenizer.model_max_length
@property
def batch_size(self):
# TODO: fix multi-gpu
return self._batch_size # * gpus
@property
def device(self):
# TODO: fix multi-gpu
return self._device
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def _model_call(self, inputs_tok, targets_tok):
"""
inputs_tok: a tokenizer BatchEncoding for the encoder inputs; its input_ids have shape [batch, sequence], and the sequence length may vary from call to call
targets_tok: a tokenizer BatchEncoding whose input_ids are passed to the model as decoder labels
returns: the seq2seq model output; outputs.logits has shape [batch, target_sequence, vocab]
"""
with torch.no_grad():
return self.t5(
**inputs_tok,
labels=targets_tok["input_ids"]
)
def loglikelihood(self, requests):
res = []
for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)):
inputs, targets = zip(*chunk)
# Fill in empty encoder inputs with eos_token
inputs = (
f"{self.eot_token}"
if len(input_) == 0
else input_
for input_ in inputs
)
inputs_tok = self.tokenizer(
list(inputs),
max_length=self.max_length,
......@@ -62,7 +122,7 @@ class T5LM(LM):
targets_tok = self.tokenizer(
list(targets),
max_length=self.MAX_GEN_TOKS,
max_length=self.max_gen_toks,
padding=True,
# truncation=True,
add_special_tokens=False,
......@@ -71,12 +131,8 @@ class T5LM(LM):
for key in targets_tok:
targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]
with torch.no_grad():
outputs = self.t5(
**inputs_tok,
labels=targets_tok["input_ids"]
)
outputs = self._model_call(inputs_tok, targets_tok)
log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
......@@ -103,9 +159,6 @@ class T5LM(LM):
res.append(answer)
return res
def loglikelihood_rolling(self, requests):
raise NotImplementedError
def _get_stopping_criteria(self, stopping_criteria_ids):
class MultitokenEOSCriteria(transformers.StoppingCriteria):
......@@ -133,29 +186,11 @@ class T5LM(LM):
EOSCriteria(self.tokenizer.eos_token)
])
def greedy_until(self, requests):
res = []
for context, until in tqdm(requests):
if isinstance(until, str): until = [until]
context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
stopping_criteria_ids = self.tokenizer.encode(until[0])
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
cont = self.t5.generate(
context_enc,
max_length=self.MAX_GEN_TOKS,
stopping_criteria=stopping_criteria,
do_sample=False
)
s = self.tokenizer.decode(cont[0].tolist())
self.cache_hook.add_partial("greedy_until", (context, until), s)
res.append(s)
return res
\ No newline at end of file
def _model_generate(self, context, max_length, stopping_criteria_ids):
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
return self.t5.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
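A small self-contained illustration of the empty-encoder-input handling added in T5LM.loglikelihood above (the "</s>" value is what eot_token returns for a T5 tokenizer):

    inputs = ["", "translate English to German: hello"]
    eot = "</s>"  # T5's eos token, as returned by self.eot_token
    filled = [eot if len(x) == 0 else x for x in inputs]
    # filled == ["</s>", "translate English to German: hello"]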
......@@ -53,6 +53,9 @@ from . import asdiv
from . import gsm8k
from . import storycloze
from . import hans
from . import gem_webnlg
from . import gem_xsum
from . import gem_mlsum
from . import e2e_nlg_cleaned
########################################
......@@ -107,6 +110,7 @@ TASK_REGISTRY = {
"wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre?
"coqa": coqa.CoQA,
"GEM/web_nlg": gem_webnlg.WebNLG,
"drop": drop.DROP,
"lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze,
......@@ -284,10 +288,26 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# GEM/mlsum
"mlsum_es": gem_mlsum.GEMMLSUMEs,
"mlsum_de": gem_mlsum.GEMMLSUMDe,
"mlsum_es_covid_challenge_set": gem_mlsum.GEMMLSUMEsChallgeTestCovid,
"mlsum_de_covid_challenge_set": gem_mlsum.GEMMLSUMDeChallgeTestCovid,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
# "sat": sat.SATAnalogies,
# GEM/xsum
"gem_xsum": gem_xsum.GEMXSUM,
"gem_xsum_challenge_sample": gem_xsum.GEMXSUMChallgeSample,
"gem_xsum_challenge_test_backtranslation": gem_xsum.GEMXSUMChallgeTestBacktranslation,
"gem_xsum_challenge_test_bfp_02": gem_xsum.GEMXSUMChallgeTestBFP02,
"gem_xsum_challenge_test_bfp_05": gem_xsum.GEMXSUMChallgeTestBFP05,
"gem_xsum_challenge_test_nopunc": gem_xsum.GEMXSUMChallgeTestNopunc,
"gem_xsum_challenge_test_covid": gem_xsum.GEMXSUMChallgeTestCovid,
}
......
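The new GEM entries resolve like any other registry key; a small sketch (the lm_eval.tasks module path is assumed from this file's location):

    from lm_eval import tasks

    webnlg_cls = tasks.TASK_REGISTRY["GEM/web_nlg"]   # gem_webnlg.WebNLG
    xsum_cls = tasks.TASK_REGISTRY["gem_xsum"]        # gem_xsum.GEMXSUM
    mlsum_es_cls = tasks.TASK_REGISTRY["mlsum_es"]    # gem_mlsum.GEMMLSUMEs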
......@@ -10,7 +10,7 @@ grammars.
Homepage: https://github.com/alexwarstadt/blimp
"""
from lm_eval.base import rf, Task
from lm_eval.base import rf, PromptSourceTask
from lm_eval.metrics import mean
......@@ -31,7 +31,7 @@ _CITATION = """
"""
class BlimpTask(Task):
class BlimpTask(PromptSourceTask):
VERSION = 0
DATASET_PATH = "blimp"
......@@ -50,58 +50,6 @@ class BlimpTask(Task):
# trained on this data.
return self.dataset["train"]
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
return ""
def doc_to_text(self, doc):
# this method is invoked by tests only
return ""
def doc_to_target(self, doc):
# this method is invoked by tests only
return ""
def construct_requests(self, doc, ctx):
assert not ctx
# Calculate the loglikelihood for the good and the bad sentence.
# Note that loglikelihood translates the "" prefix to the "<|endoftext|>" token
return [
rf.loglikelihood("", doc["sentence_good"]),
rf.loglikelihood("", doc["sentence_bad"]),
]
def process_results(self, doc, results):
likelihood1, likelihood2 = results
# the model got this case right iff the good sentence scored higher than the bad sentence
acc = 1.0 if likelihood1 > likelihood2 else 0.0
return {
"acc": acc,
}
def higher_is_better(self):
return {
"acc": True,
}
def aggregation(self):
return {
"acc": mean,
}
class BlimpAdjunctIsland(BlimpTask):
DATASET_NAME = "adjunct_island"
......
......@@ -118,25 +118,18 @@ class CoQA(PromptSourceTask):
"""
target = self.doc_to_target(doc).strip()
pred = results[0].strip().split("\n")[0]
print("*" * 80)
print(f"DOC: {doc}")
# print(f"PS: {self.prompt.apply(doc)}")
print(f"TEXT: {self.doc_to_text(doc)}")
print(f"TARGET: {target} END TARGET")
print(f"PRED: {pred} END PRED")
print("*" * 80)
# turn_id = len(doc["questions"]["input_text"])
# gold_list = self.get_answers(doc, turn_id)
# TODO: Add HF metrics mapped from promptsource metadata.
scores = self.compute_scores([target], pred)
return {
out = {
"f1": scores["f1"],
"em": scores["em"],
}
if self.save_examples:
example = {"target": target, "pred": pred}
return out, example
return out
def higher_is_better(self):
return {
"f1": True,
......
"""
MLSUM: The Multilingual Summarization Corpus
https://aclanthology.org/2020.emnlp-main.647/
This is the MLSUM subset of the GEM benchmark. MLSUM is the first large-scale MultiLingual SUMmarization dataset.
Obtained from online newspapers, it contains 1.5M+ article/summary pairs in five different languages -- namely, French, German, Spanish, Russian, Turkish.
Together with English newspapers from the popular CNN/Daily mail dataset, the collected data form a large scale multilingual dataset which can enable new research directions for the text summarization community.
We report cross-lingual comparative analyses based on state-of-the-art systems.
These highlight existing biases which motivate the use of a multi-lingual dataset.
Homepage: https://gitlab.lip6.fr/scialom/mlsum_data/-/raw/master/MLSUM/
"""
from lm_eval.base import PromptSourceTask
_CITATION = """
@article{scialom2020mlsum,
title={MLSUM: The Multilingual Summarization Corpus},
author={Scialom, Thomas and Dray, Paul-Alexis and Lamprier, Sylvain and Piwowarski, Benjamin and Staiano, Jacopo},
journal={arXiv preprint arXiv:2004.14900},
year={2020}
}
"""
class GEMMLSUMEsBase(PromptSourceTask):
VERSION = 0
DATASET_PATH = "GEM/mlsum"
DATASET_NAME = "es"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
def stopping_criteria(self):
return "."
class GEMMLSUMEs(GEMMLSUMEsBase):
'''this is for train/validation/test'''
SPLIT = ''
class GEMMLSUMEsChallgeTestCovid(GEMMLSUMEsBase):
'''this is for challenge_test_covid'''
SPLIT = 'challenge_test_covid'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMMLSUMDeBase(PromptSourceTask):
VERSION = 0
DATASET_PATH = "GEM/mlsum"
DATASET_NAME = "de"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
def stopping_criteria(self):
return "."
class GEMMLSUMDe(GEMMLSUMDeBase):
'''this is for train/validation/test'''
SPLIT = ''
class GEMMLSUMDeChallgeTestCovid(GEMMLSUMDeBase):
'''this is for challenge_test_covid'''
SPLIT = 'challenge_test_covid'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
from lm_eval.base import PromptSourceTask
class WebNLG(PromptSourceTask):
VERSION = 0
DATASET_PATH = "GEM/web_nlg"
DATASET_NAME = "en"
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
def stopping_criteria(self):
return '*'
def max_generation_length(self):
return 250
"""
Don’t Give Me the Details, Just the Summary! Topic-Aware Convolutional Neural Networks for Extreme Summarization
https://arxiv.org/pdf/1808.08745.pdf
The dataset is for abstractive summarization in its extreme form: summarizing a document in a single sentence. It introduces extreme summarization, a new single-document summarization task which does not favor extractive strategies and calls for an abstractive modeling approach. The idea is to create a short, one-sentence news summary answering the question "What is the article about?".
This task uses the version of the dataset distributed as part of the GEM benchmark:
The GEM Benchmark: Natural Language Generation, its Evaluation and Metrics
https://arxiv.org/pdf/2102.01672v3.pdf
Homepage: https://github.com/EdinburghNLP/XSum
GEM data card: https://gem-benchmark.com/data_cards/XSum
"""
from lm_eval.base import PromptSourceTask
_CITATION = """
@InProceedings{xsum-emnlp,
author = "Shashi Narayan and Shay B. Cohen and Mirella Lapata",
title = "Don't Give Me the Details, Just the Summary! {T}opic-Aware Convolutional Neural Networks for Extreme Summarization",
booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing ",
year = "2018",
address = "Brussels, Belgium",
}
"""
class GEMXSUMBase(PromptSourceTask):
VERSION = 0
DATASET_PATH = "GEM/xsum"
DATASET_NAME = None
SPLIT = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def stopping_criteria(self):
return '.'
def training_docs(self):
if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster
# few-shot processing. If the data is too large to fit in memory,
# return the training data as a generator instead of a list.
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
class GEMXSUM(GEMXSUMBase):
'''this is for train/validation/test'''
SPLIT = ''
class GEMXSUMChallgeSample(GEMXSUMBase):
'''this is for challenge_train_sample/challenge_validation_sample'''
SPLIT = 'challenge_sample'
def has_test_docs(self):
return False
def training_docs(self):
if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster
# few-shot processing. If the data is too large to fit in memory,
# return the training data as a generator instead of a list.
if self._training_docs is None:
self._training_docs = list(self.dataset["challenge_train_sample"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["challenge_validation_sample"]
class GEMXSUMChallgeTestBacktranslation(GEMXSUMBase):
'''this is for challenge_test_backtranslation'''
SPLIT = 'challenge_test_backtranslation'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMXSUMChallgeTestBFP02(GEMXSUMBase):
'''this is for challenge_test_bfp_02'''
SPLIT = 'challenge_test_bfp_02'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMXSUMChallgeTestBFP05(GEMXSUMBase):
'''this is for challenge_test_bfp_05'''
SPLIT = 'challenge_test_bfp_05'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMXSUMChallgeTestNopunc(GEMXSUMBase):
'''this is for challenge_test_nopunc'''
SPLIT = 'challenge_test_nopunc'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMXSUMChallgeTestCovid(GEMXSUMBase):
'''this is for challenge_test_covid'''
SPLIT = 'challenge_test_covid'
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
\ No newline at end of file
......@@ -277,20 +277,18 @@ class EthicsUtilitarianism(Ethics):
DATASET_NAME = "utilitarianism"
def training_docs(self):
rnd = random.Random()
for doc in self.dataset["train"]:
yield self._process_doc(doc, rnd)
yield self._process_doc(doc)
def validation_docs(self):
raise NotImplementedError
def test_docs(self):
rnd = random.Random()
for doc in self.dataset["test"]:
yield self._process_doc(doc, rnd)
yield self._process_doc(doc)
def _process_doc(self, doc, rnd):
rnd.seed(doc["activity"])
def _process_doc(self, doc):
rnd = random.Random(doc["activity"])
scenarios = [doc["activity"], doc["baseline"]]
ordering = [0, 1]
rnd.shuffle(ordering)
......
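A small standard-library illustration of why the rewritten _process_doc is deterministic per document: seeding a fresh random.Random with the document text fixes the shuffle order, so repeated passes over the same doc agree.

    import random

    def order_for(activity):
        rnd = random.Random(activity)  # per-doc generator seeded by the doc text
        ordering = [0, 1]
        rnd.shuffle(ordering)
        return ordering

    assert order_for("I walked the dog.") == order_for("I walked the dog.")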
......@@ -38,15 +38,15 @@ class Math(Task):
return True
def training_docs(self):
return map(self._load_doc, self.dataset["train"])
return map(self._process_doc, self.dataset["train"])
def validation_docs(self):
return NotImplemented
def test_docs(self):
return map(self._load_doc, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _load_doc(self, doc):
def _process_doc(self, doc):
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
return doc
......
......@@ -76,15 +76,15 @@ class WikiText(PerplexityTask):
return True
def training_docs(self):
return map(self._load_doc, self.dataset["train"])
return map(self._process_doc, self.dataset["train"])
def validation_docs(self):
return map(self._load_doc, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._load_doc, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _load_doc(self, doc):
def _process_doc(self, doc):
return doc["page"]
def doc_to_target(self, doc):
......
......@@ -53,9 +53,9 @@ class WinogradSchemaChallenge273(Task):
return True
def test_docs(self):
return map(self._load_doc, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _load_doc(self, doc):
def _process_doc(self, doc):
# The HF implementation of `wsc273` is not `partial evaluation` friendly.
doc["text"] = doc["text"].replace(" ", " ")
doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
......
......@@ -9,27 +9,29 @@ logging.getLogger("openai").setLevel(logging.WARNING)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True)
parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--device', type=str, default=None)
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--description_dict_path', default=None)
parser.add_argument('--check_integrity', action="store_true")
parser.add_argument("--model", required=True)
parser.add_argument("--model_args", default="")
parser.add_argument("--tasks", default="all_tasks")
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--num_fewshot", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--output_path", default=None)
parser.add_argument("--limit", type=int, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
return parser.parse_args()
def main():
args = parse_args()
assert not args.provide_description # not implemented
if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
print(
"WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.tasks == "all_tasks":
task_names = tasks.ALL_TASKS
......@@ -38,7 +40,7 @@ def main():
description_dict = {}
if args.description_dict_path:
with open(args.description_dict_path, 'r') as f:
with open(args.description_dict_path, "r") as f:
description_dict = json.load(f)
results = evaluator.simple_evaluate(
......@@ -51,11 +53,12 @@ def main():
no_cache=args.no_cache,
limit=args.limit,
description_dict=description_dict,
check_integrity=args.check_integrity
check_integrity=args.check_integrity,
)
print(results)
dumped = json.dumps(results, indent=2)
print(dumped)
if args.output_path:
......
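With these flags, a typical smoke-test run looks like: python main.py --model <model-key> --model_args pretrained=gpt2 --tasks gem_xsum --num_fewshot 0 --limit 8 --output_path results.json, where <model-key> is one of the names in the model registry (the registered names for the HFLM, T0LM, and T5LM classes touched in this commit) and --limit is for testing only, as the warning above states.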
......@@ -56,11 +56,11 @@ def main():
docs = join_iters(iters)
description = description_dict[task_name] if description_dict and task_name in description_dict else ""
task_name = task_name.replace('/','_')
with open(os.path.join(args.output_base_path, task_name), "w") as f:
for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs):
f.write(EXAMPLE_DIVIDER.format(i=i))
ctx = task.fewshot_context(
ctx, _ = task.fewshot_context(
doc=doc,
num_fewshot=args.num_fewshot,
rnd=rnd,
......
......@@ -20,7 +20,7 @@ setuptools.setup(
],
python_requires=">=3.6",
install_requires=[
"promptsource",
"promptsource @ git+https://github.com/bigscience-workshop/promptsource@eval-hackathon",
"wrapt",
"nltk",
"jinja2",
......@@ -37,7 +37,6 @@ setuptools.setup(
"pycountry==20.7.3",
"numexpr==2.7.2",
"lm_dataformat==0.0.20",
"pytest==6.2.3",
"pybind11==2.6.2",
"tqdm-multiprocess==0.0.11",
"zstandard==0.15.2",
......@@ -51,4 +50,5 @@ setuptools.setup(
dependency_links=[
"https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
],
extras_require={'dev': [ 'pytest', 'black' ]}
)
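In practice this means a plain editable install ("pip install -e .") now pulls promptsource from the pinned eval-hackathon branch, while pytest (dropped from install_requires) and black are available through the new dev extra ("pip install -e .[dev]").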