Commit 88745155 authored by cjlovering

Initial integration

parent 6caa0afd
import abc
from typing import Iterable
import numpy as np
import random
import re
......@@ -24,17 +25,17 @@ class LM(abc.ABC):
@abstractmethod
def loglikelihood(self, requests):
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
:param requests: list
A list of pairs (context, continuation)
context: str
Context string. Implementations of LM must be able to handle an
Context string. Implementations of LM must be able to handle an
empty context string.
continuation: str
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
For example, context="hello" continuation=" world" is correct.
:return: list
A list of pairs (logprob, isgreedy)
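As a minimal sketch of the request/response shape described above (illustrative values only; `lm` stands for any concrete LM implementation):
requests = [("The capital of France is", " Paris")]  # (context, continuation)
results = lm.loglikelihood(requests)
logprob, is_greedy = results[0]  # e.g. (-1.4, True): log-prob of " Paris" and whether it was the greedy continuation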
......@@ -97,7 +98,7 @@ class LM(abc.ABC):
context: str
Context string
until: [str]
The string sequences to generate until. These string sequences
The string sequences to generate until. These string sequences
may each span across multiple tokens, or may be part of one token.
:return: list
A list of strings continuation
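A matching sketch for greedy_until (hypothetical strings; `lm` again stands for any concrete LM implementation):
requests = [("Q: What is 2 + 2?\nA:", ["\nQ:"])]  # (context, until)
continuations = lm.greedy_until(requests)  # e.g. [" 4"]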
......@@ -118,7 +119,6 @@ class LM(abc.ABC):
class BaseLM(LM):
@property
@abstractmethod
def eot_token_id(self):
......@@ -145,13 +145,16 @@ class BaseLM(LM):
pass
@abstractmethod
def tok_encode(self, string: str): pass
def tok_encode(self, string: str):
pass
@abstractmethod
def tok_decode(self, tokens: Iterable[int]): pass
def tok_decode(self, tokens: Iterable[int]):
pass
@abstractmethod
def _model_generate(self, context, max_length, eos_token_id): pass
def _model_generate(self, context, max_length, eos_token_id):
pass
@abstractmethod
def _model_call(self, inps):
......@@ -187,23 +190,30 @@ class BaseLM(LM):
# TODO: automatic batch size detection for vectorization
loglikelihoods = []
for string, in tqdm(requests):
rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
)))
for (string,) in tqdm(requests):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
# that
string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
string_nll = self._loglikelihood_tokens(
rolling_token_windows, disable_tqdm=True
)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
......@@ -223,10 +233,12 @@ class BaseLM(LM):
toks = x[1] + x[2]
return -len(toks), tuple(toks)
# TODO: automatic (variable) batch size detection for vectorization
reord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
for chunk in utils.chunks(
tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
):
inps = []
cont_toks_list = []
inplens = []
......@@ -252,44 +264,60 @@ class BaseLM(LM):
# when too long to fit in context, truncate from the left
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length+1):][:-1],
dtype=torch.long
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
).to(self.device)
inplen, = inp.shape
(inplen,) = inp.shape
cont = continuation_enc
# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = padding_length if padding_length is not None else inplen
padding_length = (
padding_length if padding_length is not None else inplen
)
# pad length from seq to padding_length
inp = torch.cat([
inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
], dim=0)
inp = torch.cat(
[
inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(
inp.device
), # [padding_length - seq]
],
dim=0,
)
inps.append(inp.unsqueeze(0)) # [1, padding_length]
cont_toks_list.append(cont)
inplens.append(inplen)
batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu() # [batch, padding_length, vocab]
multi_logits = F.log_softmax(
self._model_call(batched_inps), dim=-1
).cpu() # [batch, padding_length, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks \
in zip(chunk, multi_logits, inps, inplens, cont_toks_list):
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
chunk, multi_logits, inps, inplens, cont_toks_list
):
# Slice to original seq length
contlen = len(cont_toks)
logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab]
logits = logits[inplen - contlen : inplen].unsqueeze(
0
) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1)
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0) # [1, seq]
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
0
) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
# Obtain log-probs at the corresponding continuation token indices
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-1
) # [1, seq]
# Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal))
......@@ -301,9 +329,9 @@ class BaseLM(LM):
res.append(answer)
return reord.get_original(res)
def greedy_until(self, requests):
# TODO: implement fully general `until` that handles untils that are
# TODO: implement fully general `until` that handles untils that are
# multiple tokens or that span multiple tokens correctly
# TODO: extract to TokenizedLM?
......@@ -312,29 +340,33 @@ class BaseLM(LM):
def _collate(x):
toks = self.tok_encode(x[0])
return len(toks), x[0]
reord = utils.Reorderer(requests, _collate)
for context, until in tqdm(reord.get_reordered()):
if isinstance(until, str):
until = [until]
primary_until, = self.tok_encode(until[0])
context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)
(primary_until,) = self.tok_encode(until[0])
context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
).to(self.device)
cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)
cont = self._model_generate(
context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
for term in until:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until), s)
res.append(s)
return reord.get_original(res)
......@@ -383,7 +415,7 @@ class Task(abc.ABC):
self._fewshot_docs = None
def download(self, data_dir=None, cache_dir=None, download_mode=None):
""" Downloads and returns the task dataset.
"""Downloads and returns the task dataset.
Override this method to download the dataset from a custom API.
:param data_dir: str
......@@ -412,7 +444,7 @@ class Task(abc.ABC):
name=self.DATASET_NAME,
data_dir=data_dir,
cache_dir=cache_dir,
download_mode=download_mode
download_mode=download_mode,
)
@abstractmethod
......@@ -478,22 +510,22 @@ class Task(abc.ABC):
@abstractmethod
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
pass
@abstractmethod
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -507,7 +539,7 @@ class Task(abc.ABC):
def aggregation(self):
"""
:returns: {str: [metric_score] -> float}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metric scores
"""
pass
......@@ -516,22 +548,26 @@ class Task(abc.ABC):
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
pass
def fewshot_description(self):
import warnings
warnings.warn(
"`fewshot_description` will be removed in future versions. Pass "
"any custom descriptions to the `evaluate` function instead.",
DeprecationWarning)
DeprecationWarning,
)
return ""
@utils.positional_deprecated
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
""" Returns a fewshot context string that is made up of a prepended description
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
......@@ -548,7 +584,9 @@ class Task(abc.ABC):
:returns: str
The fewshot context.
"""
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
......@@ -556,7 +594,9 @@ class Task(abc.ABC):
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
description = description + "\n\n" if description else ""
......@@ -569,7 +609,9 @@ class Task(abc.ABC):
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs() if self.has_validation_docs() else self.test_docs()
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
......@@ -577,23 +619,90 @@ class Task(abc.ABC):
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = "\n\n".join(
[self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
) + "\n\n"
labeled_examples = (
"\n\n".join(
[
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc)
return description + labeled_examples + example
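Assembled, the returned string has the shape sketched below (a hypothetical 1-shot context; the actual contents come from the task's doc_to_text/doc_to_target):
ctx = (
    "Answer the question.\n\n"          # description, if provided
    "Q: Is the sky blue?\nA: yes\n\n"   # one labeled fewshot example
    "Q: Is grass red?\nA:"              # the unanswered prompt for `doc`
)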
class MultipleChoiceTask(Task):
class PromptSourceTask(Task):
def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
super().__init__(data_dir, cache_dir, download_mode)
self.prompt = prompt
def doc_to_target(self, doc):
return " " + doc['choices'][doc['gold']]
_, target = self.prompt.apply(doc)
return f" {target}"
def doc_to_text(self, doc):
text, _ = self.prompt.apply(doc)
return text
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
_requests = []
if self.prompt.metadata.choices_in_prompt:
for answer_choice in self.prompt.get_fixed_answer_choices_list():
ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
_requests.append(ll_answer_choice)
else:
# TODO(Albert): What is the stop symbol? Is it model specific?
ll_greedy = rf.greedy_until(ctx, ["\nQ:"])
_requests.append(ll_greedy)
return _requests
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
raise NotImplementedError(
"Implement process results using the `prompt.metadata.metrics`. See below."
)
if self.prompt.metadata.choices_in_prompt:
for answer_choice, result in zip(
self.prompt.get_fixed_answer_choices_list(), results
):
pass
else:
continuation = results
# Map metric name to HF metric.
# TODO(Albert): What is Other?
metric_names = self.prompt.metadata.metrics
class MultipleChoiceTask(Task):
def doc_to_target(self, doc):
return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc, ctx):
lls = [
rf.loglikelihood(ctx, " {}".format(choice))[0]
for choice in doc['choices']
rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
]
return lls
......@@ -601,21 +710,21 @@ class MultipleChoiceTask(Task):
def process_results(self, doc, results):
gold = doc["gold"]
acc = 1. if np.argmax(results) == gold else 0.
acc = 1.0 if np.argmax(results) == gold else 0.0
completion_len = np.array([float(len(i)) for i in doc["choices"]])
acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.
acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
return {
"acc": acc,
"acc_norm": acc_norm,
}
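A small worked example of the two metrics above (hypothetical numbers): length normalization can flip the prediction.
results = [-4.0, -6.0]                 # summed log-probs for two choices
completion_len = np.array([4.0, 8.0])  # character lengths of the choices
# np.argmax(results) -> 0, but results / completion_len == [-1.0, -0.75],
# so np.argmax(results / completion_len) -> 1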
def higher_is_better(self):
return {
"acc": True,
"acc_norm": True,
}
def aggregation(self):
return {
"acc": mean,
......@@ -624,7 +733,6 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task, abc.ABC):
def has_training_docs(self):
return False
......@@ -632,9 +740,15 @@ class PerplexityTask(Task, abc.ABC):
assert k == 0
return []
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, "The number of fewshot examples must be 0 for perplexity tasks."
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`."
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert (
num_fewshot == 0
), "The number of fewshot examples must be 0 for perplexity tasks."
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`."
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
......@@ -642,7 +756,9 @@ class PerplexityTask(Task, abc.ABC):
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
return ""
......@@ -665,7 +781,7 @@ class PerplexityTask(Task, abc.ABC):
return req
def process_results(self, doc, results):
loglikelihood, = results
(loglikelihood,) = results
words = self.count_words(doc)
bytes_ = self.count_bytes(doc)
return {
......@@ -687,23 +803,23 @@ class PerplexityTask(Task, abc.ABC):
@classmethod
def count_words(cls, doc):
""" Downstream tasks with custom word boundaries should override this! """
"""Downstream tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
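For example, the default whitespace splitting above counts three words here:
assert len(re.split(r"\s+", "hello  world\nfoo")) == 3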
def hash_args(attr, args):
dat = json.dumps([attr] + list(args))
return hashlib.sha256(dat.encode('utf-8')).hexdigest()
return hashlib.sha256(dat.encode("utf-8")).hexdigest()
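A usage sketch (argument values are placeholders; assumes the module's json/hashlib imports): the digest doubles as the cache key in CachingLM below.
key = hash_args("loglikelihood", ("The capital of France is", " Paris"))
# deterministic 64-character sha256 hex string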
class CacheHook:
def __init__(self, cachinglm):
if cachinglm is None:
if cachinglm is None:
self.dbdict = None
return
self.dbdict = cachinglm.dbdict
def add_partial(self, attr, req, res):
if self.dbdict is None:
return
......@@ -733,7 +849,7 @@ class CachingLM:
def fn(requests):
res = []
remaining_reqs = []
# figure out which ones are cached and which ones are new
for req in requests:
hsh = hash_args(attr, req)
......@@ -746,7 +862,7 @@ class CachingLM:
else:
res.append(None)
remaining_reqs.append(req)
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
......@@ -764,41 +880,48 @@ class CachingLM:
self.dbdict.commit()
return res
return fn
def get_cache_hook(self):
return CacheHook(self)
REQUEST_RETURN_LENGTHS = {
'loglikelihood': 2,
'greedy_until': None,
'loglikelihood_rolling': None,
"loglikelihood": 2,
"greedy_until": None,
"loglikelihood_rolling": None,
}
class Request:
def __init__(self, request_type, args, index=None):
if request_type not in REQUEST_RETURN_LENGTHS.keys():
raise NotImplementedError('The request type {} is not implemented!'.format(request_type))
raise NotImplementedError(
"The request type {} is not implemented!".format(request_type)
)
self.request_type = request_type
self.args = args
self.index = index
def __iter__(self):
if REQUEST_RETURN_LENGTHS[self.request_type] is None:
raise IndexError('This request type does not return multiple arguments!')
raise IndexError("This request type does not return multiple arguments!")
for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
yield Request(self.request_type, self.args, i)
def __getitem__(self, i):
if REQUEST_RETURN_LENGTHS[self.request_type] is None:
raise IndexError('This request type does not return multiple arguments!')
raise IndexError("This request type does not return multiple arguments!")
return Request(self.request_type, self.args, i)
def __eq__(self, other):
return self.request_type == other.request_type and self.args == other.args and self.index == other.index
return (
self.request_type == other.request_type
and self.args == other.args
and self.index == other.index
)
def __repr__(self):
return f"Req_{self.request_type}{self.args}[{self.index}]\n"
......@@ -808,6 +931,7 @@ class RequestFactory:
def __getattr__(self, attr):
def fn(*args):
return Request(attr, args)
return fn
......@@ -6,21 +6,33 @@ import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import promptsource
import numpy as np
from promptsource.templates import DatasetTemplates
from lm_eval.utils import positional_deprecated, run_task_tests
@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None, check_integrity=False):
def simple_evaluate(
model,
model_args=None,
tasks=[],
num_fewshot=0,
batch_size=None,
device=None,
no_cache=False,
limit=None,
bootstrap_iters=100000,
description_dict=None,
check_integrity=False,
):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
Name of model or LM object, see lm_eval.models.get_model
:param model_args: Optional[str]
String arguments for each model class, see LM.create_from_arg_string.
String arguments for each model class, see LM.create_from_arg_string.
Ignored if `model` argument is a LM object.
:param tasks: list[Union[str, Task]]
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
......@@ -37,7 +49,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:return
......@@ -49,20 +61,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
assert tasks != [], "No tasks specified"
if isinstance(model, str):
if model_args is None: model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
'batch_size': batch_size, 'device': device
})
if model_args is None:
model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "device": device}
)
else:
assert isinstance(model, lm_eval.base.LM)
lm = model
if not no_cache:
lm = lm_eval.base.CachingLM(
lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
lm,
"lm_cache/"
+ model
+ "_"
+ model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+ ".db",
)
task_dict = lm_eval.tasks.get_task_dict(tasks)
task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)
if check_integrity:
run_task_tests(task_list=tasks)
......@@ -72,7 +90,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
task_dict=task_dict,
num_fewshot=num_fewshot,
limit=limit,
description_dict=description_dict
description_dict=description_dict,
)
# add info about the model and few shot config
......@@ -85,14 +103,22 @@ def simple_evaluate(model, model_args=None, tasks=[],
"no_cache": no_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"description_dict": description_dict
"description_dict": description_dict,
}
return results
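A hedged example invocation (model and task names are placeholders; valid keys depend on lm_eval.models and on which promptsource templates exist for the task):
results = simple_evaluate(
    model="gpt2",
    tasks=["boolq"],
    num_fewshot=0,
    limit=8,  # evaluate only a few documents while smoke-testing
)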
@positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None):
def evaluate(
lm,
task_dict,
provide_description=None,
num_fewshot=0,
limit=None,
bootstrap_iters=100000,
description_dict=None,
):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
......@@ -108,7 +134,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
Dictionary of custom task descriptions of the form: `task_name: description`
:return
Dictionary of results
"""
......@@ -118,12 +144,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
assert not provide_description # not implemented.
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
task_dict_items = [
(name, task)
for name, task in task_dict.items()
if(task.has_validation_docs() or task.has_test_docs())
if (task.has_validation_docs() or task.has_test_docs())
]
results = collections.defaultdict(dict)
......@@ -158,15 +186,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
rnd.seed(42)
rnd.shuffle(task_docs)
description = description_dict[task_name] if description_dict and task_name in description_dict else ""
description = (
description_dict[task_name]
if description_dict and task_name in description_dict
else ""
)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)):
......@@ -189,11 +218,13 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
resps = [
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))
vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task
......@@ -207,25 +238,29 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
metrics = task.process_results(doc, requests)
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
task_name, prompt_name = task_name.split("+")
results[task_name]["task_name"] = task_name
results[task_name]["prompt_name"] = prompt_name
# aggregate results
for (task_name, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric] = task.aggregation()[metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric(
metric=task.aggregation()[metric],
bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
bootstrap_iters=min(bootstrap_iters, 1000)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
return {
"results": dict(results),
"versions": dict(versions)
}
return {"results": dict(results), "versions": dict(versions)}
def make_table(result_dict):
......@@ -247,9 +282,9 @@ def make_table(result_dict):
if m + "_stderr" in dic:
se = dic[m + "_stderr"]
values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
else:
values.append([k, version, m, '%.4f' % v, '', ''])
values.append([k, version, m, "%.4f" % v, "", ""])
k = ""
version = ""
md_writer.value_matrix = values
from promptsource.templates import DatasetTemplates
from pprint import pprint
from typing import List, Union
......@@ -58,8 +60,8 @@ from . import storycloze
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
"wmt14": ["en-fr", "fr-en"], # French
"wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
}
......@@ -67,7 +69,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = {
**gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic
"iwslt17": ["en-ar", "ar-en"], # Arabic
}
# 319 total
......@@ -91,7 +93,7 @@ TASK_REGISTRY = {
"rte": glue.RTE,
"qnli": glue.QNLI,
"qqp": glue.QQP,
#"stsb": glue.STSB, # not implemented yet
# "stsb": glue.STSB, # not implemented yet
"sst": glue.SST,
"wnli": glue.WNLI,
# SuperGLUE
......@@ -102,34 +104,26 @@ TASK_REGISTRY = {
"record": superglue.ReCoRD,
"wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre?
"coqa": coqa.CoQA,
"drop": drop.DROP,
"lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze,
# multilingual lambada
**lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa": piqa.PiQA,
"prost": prost.PROST,
"mc_taco": mc_taco.MCTACO,
# Science related
"pubmedqa" : pubmedqa.Pubmed_QA,
"sciq" : sciq.SciQ,
"pubmedqa": pubmedqa.Pubmed_QA,
"sciq": sciq.SciQ,
"qasper": qasper.QASPER,
"qa4mre_2011" : qa4mre.QA4MRE_2011,
"qa4mre_2012" : qa4mre.QA4MRE_2012,
"qa4mre_2013" : qa4mre.QA4MRE_2013,
"qa4mre_2011": qa4mre.QA4MRE_2011,
"qa4mre_2012": qa4mre.QA4MRE_2012,
"qa4mre_2013": qa4mre.QA4MRE_2013,
"triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
......@@ -140,7 +134,7 @@ TASK_REGISTRY = {
"squad2": squad.SQuAD2,
"race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
"headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
"headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
"headqa_es": headqa.HeadQAEs,
"headqa_en": headqa.HeadQAEn,
"mathqa": mathqa.MathQA,
......@@ -150,21 +144,17 @@ TASK_REGISTRY = {
"anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3,
"ethics_cm": hendrycks_ethics.EthicsCM,
"ethics_deontology": hendrycks_ethics.EthicsDeontology,
"ethics_justice": hendrycks_ethics.EthicsJustice,
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue
"mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus,
# math
"math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
......@@ -175,7 +165,6 @@ TASK_REGISTRY = {
"math_precalc": hendrycks_math.MathPrecalculus,
"math_asdiv": asdiv.Asdiv,
"gsm8k": gsm8k.GradeSchoolMath8K,
# arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus,
......@@ -189,22 +178,18 @@ TASK_REGISTRY = {
"arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks)
**hendrycks_test.create_all_tasks(),
# e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# Word Scrambling and Manipulation Tasks
"anagrams1": unscramble.Anagrams1,
"anagrams2": unscramble.Anagrams2,
"cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords,
# Pile
"pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3,
......@@ -228,7 +213,6 @@ TASK_REGISTRY = {
"pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
......@@ -297,7 +281,6 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
......@@ -321,19 +304,43 @@ def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items():
if class_ is task_object:
return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
task_name_dict = {
task_name: get_task(task_name)()
for task_name in task_name_list if isinstance(task_name, str)
for task_name in task_name_list
if isinstance(task_name, str)
}
task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object
for task_object in task_name_list if not isinstance(task_object, str)
for task_object in task_name_list
if not isinstance(task_object, str)
}
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict}
def get_task_dict_promptsource(task_name_list: List[str]):
"""Loads a task instance for each prompt written for that task."""
task_name_dict = {}
for task_name in task_name_list:
assert isinstance(task_name, str)
task_prompts = DatasetTemplates(task_name)
for prompt_name in task_prompts.all_template_names:
prompt = task_prompts[prompt_name]
# NOTE: We choose a sep that can be easily split.
task_name_dict[f"{task_name}+{prompt_name}"] = get_task(task_name)(
prompt=prompt
)
return task_name_dict
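A usage sketch (the dataset name is a placeholder and must have promptsource templates registered):
task_dict = get_task_dict_promptsource(["boolq"])
# keys look like "boolq+<prompt_name>", one PromptSourceTask instance per prompt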
......@@ -51,44 +51,22 @@ class CoQA(Task):
def test_docs(self):
pass
def doc_to_text(self, doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + '\n\n'
for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai
question = f"Q: {q}\n\n"
answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer
return doc_text
@classmethod
def get_answers(cls, doc, turn_id):
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
answers = []
answer_forturn = doc["answers"]["input_text"][turn_id - 1]
answers.append(answer_forturn)
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
@classmethod
def get_answer_choice(self, raw_text):
# Function maps answers to CoQA answer categories
# ~ 1/5 of the CoQA answers are Yes/No
# ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown":
return '0'
if squad_metrics.normalize_answer(raw_text) == "yes":
return '1'
if squad_metrics.normalize_answer(raw_text) == "no":
return '2'
return '3' # Not a yes/no question
# @classmethod
# def get_answers(cls, doc, turn_id):
# # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
# answers = []
# answer_forturn = doc["answers"]["input_text"][turn_id - 1]
# answers.append(answer_forturn)
# additional_answers = doc.get("additional_answers")
# if additional_answers:
# for key in additional_answers:
# additional_answer_for_turn = additional_answers[key]["input_text"][
# turn_id - 1
# ]
# if additional_answer_for_turn.lower() not in map(str.lower, answers):
# answers.append(additional_answer_for_turn)
# return answers
@staticmethod
def compute_scores(gold_list, pred):
......@@ -98,40 +76,38 @@ class CoQA(Task):
em_sum = 0.0
if len(gold_list) > 1:
for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1:]
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn.
if turnid is None:
turnid = len(doc["questions"]["input_text"])
raw_text = doc['answers']["input_text"][turnid - 1]
return " " + raw_text
return {
"em": em_sum / max(1, len(gold_list)),
"f1": f1_sum / max(1, len(gold_list)),
}
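For instance, with a single gold answer (relying on squad_metrics normalization):
CoQA.compute_scores(["Paris"], "Paris")  # -> {"em": 1.0, "f1": 1.0}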
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
cont_request = rf.greedy_until(ctx, ['\nQ:'])
cont_request = rf.greedy_until(ctx, ["\nQ:"])
return cont_request
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -139,15 +115,18 @@ class CoQA(Task):
:param results:
The results of the requests created in construct_requests.
"""
turn_id = len(doc["questions"]["input_text"])
gold_list = self.get_answers(doc, turn_id)
pred = results[0].strip().split('\n')[0]
target = self.doc_to_target(doc).strip()
pred = results[0].strip().split("\n")[0]
# turn_id = len(doc["questions"]["input_text"])
# gold_list = self.get_answers(doc, turn_id)
scores = self.compute_scores(gold_list, pred)
# TODO: Add HF metrics mapped from promptsource metadata.
scores = self.compute_scores([target], pred)
return {
"f1": scores['f1'],
"em": scores['em'],
"f1": scores["f1"],
"em": scores["em"],
}
def higher_is_better(self):
......@@ -70,21 +70,26 @@ class DROP(Task):
@classmethod
def get_answers(cls, qa):
def _flatten_validated_answers(validated_answers):
""" Flattens a dict of lists of validated answers.
"""Flattens a dict of lists of validated answers.
{"number": ['1', '8'], ...}
-> [{"number": ['1'], ...}, {"number": ['8'], ...}]
"""
vas = []
for i in range(len(validated_answers["number"])):
vas.append({
"number": validated_answers["number"][i],
"date": validated_answers["date"][i],
"spans": validated_answers["spans"][i],
})
vas.append(
{
"number": validated_answers["number"][i],
"date": validated_answers["date"][i],
"spans": validated_answers["spans"][i],
}
)
return vas
answers = []
answers_set = set()
candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
candidates = [qa["answer"]] + _flatten_validated_answers(
qa["validated_answers"]
)
for candidate in candidates:
answer = cls.parse_answer(candidate)
if answer in answers_set:
......@@ -100,15 +105,17 @@ class DROP(Task):
return (str(answer["number"]),)
if answer["spans"] != []:
return tuple(answer["spans"])
return (" ".join([answer["date"]["day"],
answer["date"]["month"],
answer["date"]["year"]]).strip(),)
return (
" ".join(
[answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
).strip(),
)
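A sketch of the branches above on a hypothetical answer dict (numeric answers take precedence over spans and dates):
ans = {"number": "3", "spans": [], "date": {"day": "", "month": "", "year": ""}}
assert DROP.parse_answer(ans) == ("3",)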
def doc_to_text(self, doc):
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
# def doc_to_text(self, doc):
# return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def doc_to_target(self, doc):
return " " + ", ".join(doc["answers"][0])
# def doc_to_target(self, doc):
# return " " + ", ".join(doc["answers"][0])
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
......@@ -134,7 +141,13 @@ class DROP(Task):
:param results:
The results of the requests created in construct_requests.
"""
preds, golds = results, doc["answers"]
pred = results[0].strip()
target = self.doc_to_target(doc).strip()
preds = [pred]
golds = [target]
max_em = 0
max_f1 = 0
for gold_answer in golds:
......@@ -142,10 +155,7 @@ class DROP(Task):
if gold_answer[0].strip():
max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score)
return {
"em": max_em,
"f1": max_f1
}
return {"em": max_em, "f1": max_f1}
def get_metrics(self, predicted, gold):
"""
......@@ -158,7 +168,9 @@ class DROP(Task):
predicted_bags = self._answer_to_bags(predicted)
gold_bags = self._answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
if set(predicted_bags[0]) == set(gold_bags[0]) and len(
predicted_bags[0]
) == len(gold_bags[0]):
exact_match = 1.0
else:
exact_match = 0.0
......@@ -190,7 +202,9 @@ class DROP(Task):
for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted):
if self._match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
scores[gold_index, pred_index] = self._compute_f1(
pred_item, gold_item
)
row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))])
......@@ -256,7 +270,11 @@ class DROP(Task):
def _normalize(self, answer):
tokens = [
self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower()))))
self._white_space_fix(
self._remove_articles(
self._fix_number(self._remove_punc(token.lower()))
)
)
for token in self._tokenize(answer)
]
tokens = [token for token in tokens if token.strip()]
......@@ -269,10 +287,7 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"em": mean,
"f1": mean
}
return {"em": mean, "f1": mean}
def higher_is_better(self):
"""
......@@ -280,7 +295,4 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"em": True,
"f1": True
}
return {"em": True, "f1": True}
......@@ -40,7 +40,7 @@ class RACE(Task):
DATASET_NAME = "high"
cache = {}
letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}
def has_training_docs(self):
return True
......@@ -59,17 +59,27 @@ class RACE(Task):
# is shown that one document is made per passage.
r = collections.defaultdict(list)
for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]:
r[item['article']].append(item)
res = list(r.values() >> each(lambda x: {
'article': x[0]['article'],
'problems': x >> each(lambda y: {
'question': y['question'],
'answer': y['answer'],
'options': y['options'],
})
}))
for item in datasets.load_dataset(
path=self.DATASET_PATH, name=self.DATASET_NAME
)[set]:
r[item["article"]].append(item)
res = list(
r.values()
>> each(
lambda x: {
"article": x[0]["article"],
"problems": x
>> each(
lambda y: {
"question": y["question"],
"answer": y["answer"],
"options": y["options"],
}
),
}
)
)
self.cache[set] = res
return res
......@@ -85,49 +95,48 @@ class RACE(Task):
@classmethod
def get_answer_option(cls, problem):
answer = cls.letter_to_num[problem['answer']]
return problem['options'][answer]
answer = cls.letter_to_num[problem["answer"]]
return problem["options"][answer]
@classmethod
def last_problem(cls, doc):
return doc['problems'][-1]
def doc_to_text(self, doc):
text = 'Article: ' + doc['article'] + '\n\n'
for problem in doc['problems'][:-1]:
if problem['question'][-6:] == ' _ .':
text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
else:
question = 'Question: ' + problem['question'] + '\n'
answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
text += question + answer
text += self.last_problem(doc)['question']
return text
def doc_to_target(self, doc):
return " " + self.get_answer_option(self.last_problem(doc))
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
problem = self.last_problem(doc)
ll_choices = [
rf.loglikelihood(ctx, " " + problem['options'][i])[0]
for i in range(4)
]
return ll_choices
return doc["problems"][-1]
# def doc_to_text(self, doc):
# text = 'Article: ' + doc['article'] + '\n\n'
# for problem in doc['problems'][:-1]:
# if problem['question'][-6:] == ' _ .':
# text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
# else:
# question = 'Question: ' + problem['question'] + '\n'
# answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
# text += question + answer
# text += self.last_problem(doc)['question']
# return text
# def doc_to_target(self, doc):
# return " " + self.get_answer_option(self.last_problem(doc))
# def construct_requests(self, doc, ctx):
# """Uses RequestFactory to construct Requests and returns an iterable of
# Requests which will be sent to the LM.
# :param doc:
# The document as returned from training_docs, validation_docs, or test_docs.
# :param ctx: str
# The context string, generated by fewshot_context. This includes the natural
# language description, as well as the few shot examples, and the question
# part of the document for `doc`.
# """
# problem = self.last_problem(doc)
# ll_choices = [
# rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
# ]
# return ll_choices
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -135,28 +144,24 @@ class RACE(Task):
:param results:
The results of the requests created in construct_requests.
"""
gold = self.letter_to_num[self.last_problem(doc)['answer']]
#
gold = self.letter_to_num[self.doc_to_target(doc)]
# gold = self.letter_to_num[self.last_problem(doc)["answer"]]
pred = np.argmax(results)
return {
"acc": int(pred == gold)
}
return {"acc": int(pred == gold)}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
return {"acc": True}