Commit b4ad893c authored by ken

Merge master

parents 8c83a821 20820c3c
@@ -173,10 +173,6 @@ def evaluate(
     # get lists of each type of request
     for task_prompt_name, task in task_dict_items:
-        # if task.is_generation_task():
-        # print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
-        # continue
         versions[task_prompt_name] = task.VERSION
         # default to test doc, fall back to val doc if validation unavailable
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
@@ -188,7 +184,7 @@ def evaluate(
             raise RuntimeError("Task has neither test_docs nor validation_docs")
         # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
-        task_docs = list(task_doc_func())
+        task_docs = list(enumerate(list(task_doc_func())))
         rnd = random.Random()
         rnd.seed(42)
         rnd.shuffle(task_docs)
@@ -199,14 +195,17 @@ def evaluate(
             else ""
         )
-        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+        for doc_id, (original_doc_id, doc) in enumerate(
+            itertools.islice(task_docs, 0, limit)
+        ):
             if task.invalid_doc_for_prompt(doc):
                 continue
             docs[(task_prompt_name, doc_id)] = doc
-            ctx = task.fewshot_context(
+            ctx, fewshotex_logging_info = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
+            fewshotex_logging_info["doc_id"] = original_doc_id
             reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):
                 reqs = [reqs]
@@ -215,7 +214,7 @@ def evaluate(
                 # i: index in requests for a single task instance
                 # doc_id: unique id that we can get back to a doc using `docs`
                 requests_origin[req.request_type].append(
-                    (i, task_prompt_name, doc, doc_id)
+                    (i, task_prompt_name, doc, doc_id, fewshotex_logging_info)
                 )

     # all responses for each (task, doc)
@@ -234,33 +233,57 @@ def evaluate(
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

-        for resp, (i, task_prompt_name, doc, doc_id) in zip(
+        for resp, (i, task_prompt_name, doc, doc_id, fewshotex_logging_info) in zip(
            resps, requests_origin[reqtype]
        ):
-            process_res_queue[(task_prompt_name, doc_id)].append((i, resp))
+            process_res_queue[(task_prompt_name, doc_id)].append(
+                (i, resp, fewshotex_logging_info)
+            )

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
-    for (task_prompt_name, doc_id), requests in process_res_queue.items():
-        requests.sort(key=lambda x: x[0])
-        requests = [x[1] for x in requests]
+    examples = []
+    for (task_prompt_name, doc_id), per_doc_requests in process_res_queue.items():
+        per_doc_requests.sort(key=lambda x: x[0])
+        per_doc_results = [x[1] for x in per_doc_requests]
+        fewshot_logging_info = [x[2] for x in per_doc_requests][0]

        task = task_dict[task_prompt_name]
        doc = docs[(task_prompt_name, doc_id)]

-        metrics = task.process_results(doc, requests)
+        output = task.process_results(doc, per_doc_results)
+        if task.save_examples:
+            metrics, example = output
+            example.update(fewshot_logging_info)
+            example.update(task.get_logging_info())
+            examples.append(example)
+        else:
+            metrics = output
+            example = fewshot_logging_info
+            example.update(task.get_logging_info())
+            examples.append(example)
        for metric, value in metrics.items():
            vals[(task_prompt_name, metric)].append(value)

    # aggregate results
+    metric_results = []
    for (task_prompt_name, metric), items in vals.items():
        task_name, prompt_name = task_prompt_name.split("+")
        results[task_prompt_name]["task_name"] = task_name
        results[task_prompt_name]["prompt_name"] = prompt_name
        task = task_dict[task_prompt_name]
        results[task_prompt_name][metric] = task.aggregation()[metric](items)
+        _metric_results = {
+            "task_name": task_name,
+            "prompt_name": prompt_name,
+            metric: task.aggregation()[metric](items),
+            **task.get_logging_info(),
+        }

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them less iterations. still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(
@@ -271,8 +294,18 @@ def evaluate(
        )
        if stderr is not None:
            results[task_prompt_name][metric + "_stderr"] = stderr(items)
+            _metric_results[metric + "_stderr"] = stderr(items)
+        metric_results.append(_metric_results)

-    return {"results": dict(results), "versions": dict(versions)}
+    return {
+        # List of results that tracks the averages per model and prompt.
+        "results": metric_results,
+        "versions": dict(versions),
+        # List of all prompt x doc examples with additional information in it.
+        "examples": examples,
+        # Original results used for generating the table when running this file.
+        "table_results": dict(results),
+    }


 def make_table(result_dict):
@@ -293,7 +326,7 @@ def make_table(result_dict):
    ]
    values = []
-    for k, dic in result_dict["results"].items():
+    for k, dic in result_dict["table_results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
...
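The reshaped return value above splits what evaluate() hands back into per-prompt aggregates ("results"), per-document records ("examples"), the version map, and the original nested dict that still drives the table ("table_results"). A minimal sketch of a downstream consumer, assuming only those four top-level keys plus the "task_name"/"prompt_name" fields added in this diff; anything else in each row depends on task.get_logging_info():

import json

def summarize(output: dict) -> None:
    # One row per (task, prompt, metric) aggregate.
    for row in output["results"]:
        extras = {k: v for k, v in row.items() if k not in ("task_name", "prompt_name")}
        print(row["task_name"], row["prompt_name"], extras)
    # One record per prompt x document, useful for error analysis.
    print(f'{len(output["examples"])} per-doc examples logged')
    # The nested per-prompt dict still feeds make_table().
    with open("table_results.json", "w") as f:
        json.dump(output["table_results"], f, indent=2)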
@@ -72,6 +72,10 @@ class HFLM(BaseLM):
         # if gpus > 1:
         # self.gpt2 = nn.DataParallel(self.gpt2)

+    @property
+    def eot_token(self):
+        return self.tokenizer.eos_token
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
...
@@ -2,37 +2,36 @@ import transformers
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from lm_eval.base import LM
+from lm_eval.base import BaseLM
 from lm_eval import utils
 from tqdm import tqdm
 import numpy as np
 import math

-class T0LM(LM):
-    MAX_GEN_TOKS = 256
-    MAX_INP_LENGTH = 512
-    VOCAB_SIZE = 32100
-    EOT_TOKEN_ID = 1
+class T0LM(BaseLM):
+    # MAX_GEN_TOKS = 256
+    # MAX_INP_LENGTH = 512
+    # VOCAB_SIZE = 32100
+    # EOT_TOKEN_ID = 1

     def __init__(self, device='cuda', parallelize=False, pretrained='t0', batch_size=1):
         super().__init__()
         if device:
-            self.device = torch.device(device)
+            self._device = torch.device(device)
         else:
-            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-        print(pretrained)
+            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

         self.t0 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
         self.t0.eval()

         if parallelize == "True":
-            print(parallelize)
             self.t0.parallelize()
-            self.device = torch.device('cuda:0')
+            self._device = torch.device('cuda:0')
         else:
-            self.t0.to(self.device)
+            self.t0.to(self._device)

         self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
-        self.max_length = self.MAX_INP_LENGTH
+        # self.max_length = self.MAX_INP_LENGTH
         self.batch_size = int(batch_size)
@@ -42,6 +41,53 @@ class T0LM(LM):
         args2 = {k: v for k, v in additional_config.items() if v is not None}
         return cls(**args, **args2)

+    @property
+    def eot_token(self):
+        return self.tokenizer.eos_token
+
+    @property
+    def eot_token_id(self):
+        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
+        return self.tokenizer.eos_token_id
+
+    @property
+    def max_length(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def max_gen_toks(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def batch_size(self):
+        # TODO: fix multi-gpu
+        return self._batch_size  # * gpus
+
+    @property
+    def device(self):
+        # TODO: fix multi-gpu
+        return self._device
+
+    def tok_encode(self, string: str):
+        return self.tokenizer.encode(string, add_special_tokens=False)
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)
+
+    def _model_call(self, inputs_tok, targets_tok):
+        """
+        inps: a torch tensor of shape [batch, sequence]
+        the size of sequence may vary from call to call
+        returns: a torch tensor of shape [batch, sequence, vocab] with the
+        logits returned from the model
+        """
+        with torch.no_grad():
+            return self.t0(
+                **inputs_tok,
+                labels=targets_tok["input_ids"]
+            )
+
     def loglikelihood(self, requests):
         res = []
         for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)):
@@ -62,7 +108,7 @@ class T0LM(LM):
             targets_tok = self.tokenizer(
                 list(targets),
-                max_length=self.MAX_GEN_TOKS,
+                max_length=self.max_gen_toks,
                 padding=True,
                 # truncation=True,
                 add_special_tokens=False,
@@ -72,11 +118,7 @@ class T0LM(LM):
             for key in targets_tok:
                 targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]

-            with torch.no_grad():
-                outputs = self.t0(
-                    **inputs_tok,
-                    labels=targets_tok["input_ids"]
-                )
+            outputs = self._model_call(inputs_tok, targets_tok)

             log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
@@ -103,9 +145,6 @@ class T0LM(LM):
             res.append(answer)
         return res

-    def loglikelihood_rolling(self, requests):
-        raise NotImplementedError
-
     def _get_stopping_criteria(self, stopping_criteria_ids):
         class MultitokenEOSCriteria(transformers.StoppingCriteria):
@@ -133,29 +172,11 @@ class T0LM(LM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def greedy_until(self, requests):
-        res = []
-        for context, until in tqdm(requests):
-            if isinstance(until, str): until = [until]
-            context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
-            stopping_criteria_ids = self.tokenizer.encode(until[0])
-            stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-            cont = self.t0.generate(
-                context_enc,
-                max_length=self.MAX_GEN_TOKS,
-                stopping_criteria=stopping_criteria,
-                do_sample=False
-            )
-            s = self.tokenizer.decode(cont[0].tolist())
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-            res.append(s)
-        return res
+    def _model_generate(self, context, max_length, stopping_criteria_ids):
+        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        return self.t0.generate(
+            context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            do_sample=False,
+        )
\ No newline at end of file
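The bodies of MultitokenEOSCriteria and EOSCriteria are collapsed in this view. Purely as an illustration of the interface _get_stopping_criteria builds on, here is a minimal sketch of a multi-token stop-sequence criterion over transformers.StoppingCriteria; it is an assumption, not the implementation hidden above:

import torch
import transformers

class MultitokenStopSketch(transformers.StoppingCriteria):
    # Illustrative only: stop once the last generated tokens equal the encoded
    # stop sequence (e.g. the ids produced by tokenizer.encode(until[0])).
    def __init__(self, stop_sequence_ids):
        self.stop_sequence_ids = list(stop_sequence_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        tail = input_ids[0, -len(self.stop_sequence_ids):].tolist()
        return tail == self.stop_sequence_ids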
@@ -2,39 +2,44 @@ import transformers
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from lm_eval.base import LM
+from lm_eval.base import BaseLM
 from lm_eval import utils
 from tqdm import tqdm
 import numpy as np
 import math

-class T5LM(LM):
-    MAX_GEN_TOKS = 256
-    MAX_INP_LENGTH = 512
-    VOCAB_SIZE = 32128
-    EOT_TOKEN_ID = 1
+class T5LM(BaseLM):
+    # MAX_GEN_TOKS = 256
+    # MAX_INP_LENGTH = 512
+    # VOCAB_SIZE = 32128
+    # EOT_TOKEN_ID = 1

-    def __init__(self, device='cuda', parallelize=False, pretrained='t5', batch_size=1):
+    def __init__(
+        self,
+        device='cuda',
+        parallelize=False,
+        pretrained='t5',
+        batch_size=1
+    ):
         super().__init__()
         if device:
-            self.device = torch.device(device)
+            self._device = torch.device(device)
         else:
-            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-        print(pretrained)
+            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

         self.t5 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
         self.t5.eval()

         if parallelize == "True":
-            print(parallelize)
             self.t5.parallelize()
-            self.device = torch.device('cuda:0')
+            self._device = torch.device('cuda:0')
         else:
-            self.t5.to(self.device)
+            self.t5.to(self._device)

-        self.tokenizer = transformers.T5TokenizerFast.from_pretrained(pretrained)
-        self.max_length = self.MAX_INP_LENGTH
-        self.batch_size = int(batch_size)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
+        # self.max_length = self.MAX_INP_LENGTH
+        self._batch_size = int(batch_size)

     @classmethod
     def create_from_arg_string(cls, arg_string, additional_config={}):
@@ -42,6 +47,53 @@ class T5LM(LM):
         args2 = {k: v for k, v in additional_config.items() if v is not None}
         return cls(**args, **args2)

+    @property
+    def eot_token(self):
+        return self.tokenizer.eos_token
+
+    @property
+    def eot_token_id(self):
+        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
+        return self.tokenizer.eos_token_id
+
+    @property
+    def max_length(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def max_gen_toks(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def batch_size(self):
+        # TODO: fix multi-gpu
+        return self._batch_size  # * gpus
+
+    @property
+    def device(self):
+        # TODO: fix multi-gpu
+        return self._device
+
+    def tok_encode(self, string: str):
+        return self.tokenizer.encode(string, add_special_tokens=False)
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)
+
+    def _model_call(self, inputs_tok, targets_tok):
+        """
+        inps: a torch tensor of shape [batch, sequence]
+        the size of sequence may vary from call to call
+        returns: a torch tensor of shape [batch, sequence, vocab] with the
+        logits returned from the model
+        """
+        with torch.no_grad():
+            return self.t5(
+                **inputs_tok,
+                labels=targets_tok["input_ids"]
+            )
+
     def loglikelihood(self, requests):
         res = []
         for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)):
@@ -62,7 +114,7 @@ class T5LM(LM):
             targets_tok = self.tokenizer(
                 list(targets),
-                max_length=self.MAX_GEN_TOKS,
+                max_length=self.max_gen_toks,
                 padding=True,
                 # truncation=True,
                 add_special_tokens=False,
@@ -72,11 +124,7 @@ class T5LM(LM):
             for key in targets_tok:
                 targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]

-            with torch.no_grad():
-                outputs = self.t5(
-                    **inputs_tok,
-                    labels=targets_tok["input_ids"]
-                )
+            outputs = self._model_call(inputs_tok, targets_tok)

             log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
@@ -103,9 +151,6 @@ class T5LM(LM):
             res.append(answer)
         return res

-    def loglikelihood_rolling(self, requests):
-        raise NotImplementedError
-
     def _get_stopping_criteria(self, stopping_criteria_ids):
         class MultitokenEOSCriteria(transformers.StoppingCriteria):
@@ -133,29 +178,11 @@ class T5LM(LM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def greedy_until(self, requests):
-        res = []
-        for context, until in tqdm(requests):
-            if isinstance(until, str): until = [until]
-            context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
-            stopping_criteria_ids = self.tokenizer.encode(until[0])
-            stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-            cont = self.t5.generate(
-                context_enc,
-                max_length=self.MAX_GEN_TOKS,
-                stopping_criteria=stopping_criteria,
-                do_sample=False
-            )
-            s = self.tokenizer.decode(cont[0].tolist())
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-            res.append(s)
-        return res
+    def _model_generate(self, context, max_length, stopping_criteria_ids):
+        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        return self.t5.generate(
+            context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            do_sample=False,
+        )
\ No newline at end of file
@@ -53,6 +53,7 @@ from . import asdiv
 from . import gsm8k
 from . import storycloze
 from . import hans
+from . import gem_webnlg
 from . import gem_xsum
 # from . import e2e_nlg_cleaned
@@ -109,6 +110,7 @@ TASK_REGISTRY = {
     "wsc": superglue.SGWinogradSchemaChallenge,
     # Order by benchmark/genre?
     "coqa": coqa.CoQA,
+    "GEM/web_nlg": gem_webnlg.WebNLG,
     "drop": drop.DROP,
     "lambada": lambada.LAMBADA,
     "lambada_cloze": lambada_cloze.LAMBADA_cloze,
...
@@ -118,25 +118,18 @@ class CoQA(PromptSourceTask):
         """
         target = self.doc_to_target(doc).strip()
         pred = results[0].strip().split("\n")[0]
-        print("*" * 80)
-        print(f"DOC: {doc}")
-        # print(f"PS: {self.prompt.apply(doc)}")
-        print(f"TEXT: {self.doc_to_text(doc)}")
-        print(f"TARGET: {target} END TARGET")
-        print(f"PRED: {pred} END PRED")
-        print("*" * 80)
-        # turn_id = len(doc["questions"]["input_text"])
-        # gold_list = self.get_answers(doc, turn_id)
-        # TODO: Add HF metrics mapped from promptsource metadata.
         scores = self.compute_scores([target], pred)
-        return {
+        out = {
             "f1": scores["f1"],
             "em": scores["em"],
         }
+        if self.save_examples:
+            example = {"target": target, "pred": pred}
+            return out, example
+        return out

     def higher_is_better(self):
         return {
             "f1": True,
...
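With save_examples, process_results may now return either a metrics dict or a (metrics, example) tuple, and evaluate() merges the few-shot logging info and task.get_logging_info() into each example before saving it. A small sketch of another task method following the same convention (the task and its fields are hypothetical):

def process_results(self, doc, results):
    # Hypothetical task mirroring the (metrics, example) convention used by CoQA above.
    target = self.doc_to_target(doc).strip()
    pred = results[0].strip()
    out = {"em": float(pred == target)}
    if self.save_examples:
        return out, {"target": target, "pred": pred}
    return out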
from lm_eval.base import PromptSourceTask


class WebNLG(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "GEM/web_nlg"
    DATASET_NAME = "en"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def stopping_criteria(self):
        return '*'

    def max_generation_length(self):
        return 250
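stopping_criteria() and max_generation_length() are the knobs the generation path consumes; a rough sketch of how a greedy request for this task could be wired through the _model_generate signature added above, assuming lm is a T0LM/T5LM-style instance (the helper itself is hypothetical):

def greedy_for_webnlg(lm, task, context: str) -> str:
    # Encode the prompt and the task's stop string ('*'), then defer to the
    # model's _model_generate, capped at the task's max generation length (250).
    context_ids = lm.tokenizer(context, return_tensors="pt").to(lm.device).input_ids
    stop_ids = lm.tokenizer.encode(task.stopping_criteria())
    output = lm._model_generate(
        context=context_ids,
        max_length=task.max_generation_length(),
        stopping_criteria_ids=stop_ids,
    )
    return lm.tok_decode(output[0].tolist())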
@@ -277,20 +277,18 @@ class EthicsUtilitarianism(Ethics):
     DATASET_NAME = "utilitarianism"

     def training_docs(self):
-        rnd = random.Random()
         for doc in self.dataset["train"]:
-            yield self._process_doc(doc, rnd)
+            yield self._process_doc(doc)

     def validation_docs(self):
         raise NotImplementedError

     def test_docs(self):
-        rnd = random.Random()
         for doc in self.dataset["test"]:
-            yield self._process_doc(doc, rnd)
+            yield self._process_doc(doc)

-    def _process_doc(self, doc, rnd):
-        rnd.seed(doc["activity"])
+    def _process_doc(self, doc):
+        rnd = random.Random(doc["activity"])
         scenarios = [doc["activity"], doc["baseline"]]
         ordering = [0, 1]
         rnd.shuffle(ordering)
...
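Seeding a fresh random.Random with the document's own "activity" text makes the scenario shuffle a pure function of the doc, so the ordering no longer depends on how many docs were iterated before it. A tiny sketch of the idea, using a placeholder string rather than real dataset content:

import random

def scenario_ordering(activity_text: str) -> list:
    # The same input always yields the same shuffled ordering.
    rnd = random.Random(activity_text)
    ordering = [0, 1]
    rnd.shuffle(ordering)
    return ordering

assert scenario_ordering("placeholder activity") == scenario_ordering("placeholder activity")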
@@ -38,15 +38,15 @@ class Math(Task):
         return True

     def training_docs(self):
-        return map(self._load_doc, self.dataset["train"])
+        return map(self._process_doc, self.dataset["train"])

     def validation_docs(self):
         return NotImplemented

     def test_docs(self):
-        return map(self._load_doc, self.dataset["test"])
+        return map(self._process_doc, self.dataset["test"])

-    def _load_doc(self, doc):
+    def _process_doc(self, doc):
         doc["answer"] = self.remove_boxed(
             self.last_boxed_only_string(doc["solution"]))
         return doc
...
@@ -76,15 +76,15 @@ class WikiText(PerplexityTask):
         return True

     def training_docs(self):
-        return map(self._load_doc, self.dataset["train"])
+        return map(self._process_doc, self.dataset["train"])

     def validation_docs(self):
-        return map(self._load_doc, self.dataset["validation"])
+        return map(self._process_doc, self.dataset["validation"])

     def test_docs(self):
-        return map(self._load_doc, self.dataset["test"])
+        return map(self._process_doc, self.dataset["test"])

-    def _load_doc(self, doc):
+    def _process_doc(self, doc):
         return doc["page"]

     def doc_to_target(self, doc):
...
@@ -53,9 +53,9 @@ class WinogradSchemaChallenge273(Task):
         return True

     def test_docs(self):
-        return map(self._load_doc, self.dataset["test"])
+        return map(self._process_doc, self.dataset["test"])

-    def _load_doc(self, doc):
+    def _process_doc(self, doc):
         # The HF implementation of `wsc273` is not `partial evaluation` friendly.
         doc["text"] = doc["text"].replace("  ", " ")
         doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
...
@@ -9,27 +9,29 @@ logging.getLogger("openai").setLevel(logging.WARNING)

 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', required=True)
-    parser.add_argument('--model_args', default="")
-    parser.add_argument('--tasks', default="all_tasks")
-    parser.add_argument('--provide_description', action="store_true")
-    parser.add_argument('--num_fewshot', type=int, default=0)
-    parser.add_argument('--batch_size', type=int, default=None)
-    parser.add_argument('--device', type=str, default=None)
-    parser.add_argument('--output_path', default=None)
-    parser.add_argument('--limit', type=int, default=None)
-    parser.add_argument('--no_cache', action="store_true")
-    parser.add_argument('--description_dict_path', default=None)
-    parser.add_argument('--check_integrity', action="store_true")
+    parser.add_argument("--model", required=True)
+    parser.add_argument("--model_args", default="")
+    parser.add_argument("--tasks", default="all_tasks")
+    parser.add_argument("--provide_description", action="store_true")
+    parser.add_argument("--num_fewshot", type=int, default=0)
+    parser.add_argument("--batch_size", type=int, default=None)
+    parser.add_argument("--device", type=str, default=None)
+    parser.add_argument("--output_path", default=None)
+    parser.add_argument("--limit", type=int, default=None)
+    parser.add_argument("--no_cache", action="store_true")
+    parser.add_argument("--description_dict_path", default=None)
+    parser.add_argument("--check_integrity", action="store_true")
     return parser.parse_args()


 def main():
     args = parse_args()
     assert not args.provide_description  # not implemented

     if args.limit:
-        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
+        print(
+            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
+        )

     if args.tasks == "all_tasks":
         task_names = tasks.ALL_TASKS
@@ -38,7 +40,7 @@ def main():
     description_dict = {}
     if args.description_dict_path:
-        with open(args.description_dict_path, 'r') as f:
+        with open(args.description_dict_path, "r") as f:
             description_dict = json.load(f)

     results = evaluator.simple_evaluate(
@@ -51,11 +53,12 @@ def main():
         no_cache=args.no_cache,
         limit=args.limit,
         description_dict=description_dict,
-        check_integrity=args.check_integrity
+        check_integrity=args.check_integrity,
     )
+    print(results)

     dumped = json.dumps(results, indent=2)
     print(dumped)

     if args.output_path:
...
@@ -56,11 +56,11 @@ def main():
     docs = join_iters(iters)

     description = description_dict[task_name] if description_dict and task_name in description_dict else ""
+    task_name = task_name.replace('/', '_')

     with open(os.path.join(args.output_base_path, task_name), "w") as f:
         for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs):
             f.write(EXAMPLE_DIVIDER.format(i=i))
-            ctx = task.fewshot_context(
+            ctx, _ = task.fewshot_context(
                 doc=doc,
                 num_fewshot=args.num_fewshot,
                 rnd=rnd,
...
@@ -37,7 +37,6 @@ setuptools.setup(
         "pycountry==20.7.3",
         "numexpr==2.7.2",
         "lm_dataformat==0.0.20",
-        "pytest==6.2.3",
         "pybind11==2.6.2",
         "tqdm-multiprocess==0.0.11",
         "zstandard==0.15.2",
@@ -51,4 +50,5 @@ setuptools.setup(
     dependency_links=[
         "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
     ],
+    extras_require={'dev': ['pytest', 'black']},
 )