Commit e1ae8a2f authored by Herbie Bradley's avatar Herbie Bradley
Browse files

Merge remote-tracking branch 'origin/big-refactor' into calibration

parents 50e99bd7 30936bc7
......@@ -18,6 +18,7 @@ import lm_eval.tasks
from lm_eval.logger import eval_logger
from lm_eval.utils import (
create_iterator,
eval_logger,
get_git_commit_hash,
make_table,
positional_deprecated,
......@@ -220,15 +221,17 @@ def evaluate(
task_hierarchy = collections.defaultdict(list)
# store the ordering of tasks and groups
task_order = collections.defaultdict(int)
# store the aggregation for aggregating across tasks in the same group
sample_agg_fn = collections.defaultdict(dict)
task_group_alias = collections.defaultdict(dict)
# get lists of each type of request
for task_name, task in task_dict.items():
if type(task) == tuple:
group_name, task = task
task_hierarchy[group_name].append(task_name)
versions[group_name] = "N/A"
else:
group_name = None
task_hierarchy[task_name] = []
if task is None:
......@@ -237,6 +240,16 @@ def evaluate(
versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config())
if "task_alias" in configs[task_name]:
task_group_alias[task_name] = configs[task_name]["task_alias"]
if (
("group_alias" in configs[task_name])
and (group_name not in task_group_alias)
and (group_name is not None)
):
task_group_alias[group_name] = configs[task_name]["group_alias"]
if limit is not None:
if task.has_test_docs():
task_docs = task.test_docs()
......@@ -248,7 +261,7 @@ def evaluate(
task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)
eval_logger.info(
eval_logger.debug(
f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
)
......@@ -263,12 +276,9 @@ def evaluate(
eval_logger.info(f"Request: {str(inst)}")
# aggregate Instances by LM method requested to get output.
reqtype = (
"loglikelihood"
if task.OUTPUT_TYPE == "multiple_choice"
else task.OUTPUT_TYPE
) # TODO: this is hacky, fix in task.py
requests[reqtype].extend(task.instances)
for instance in task.instances:
reqtype = instance.request_type
requests[reqtype].append(instance)
if lm.world_size > 1:
instances_rnk = torch.tensor(len(task._instances), device=lm.device)
......@@ -520,30 +530,15 @@ def evaluate(
group_name = None
agg_fn = task.aggregation()[metric]
task_score = agg_fn(items)
if group_name is not None:
sample_metric_key = metric + "(sample agg)," + key
for grouping in task_to_group[task_name]:
if metric_key in results[grouping]:
results[grouping][metric_key].append(task_score)
else:
results[grouping][metric_key] = [task_score]
if sample_metric_key in results[grouping]:
results[grouping][sample_metric_key] += items
else:
results[grouping][sample_metric_key] = items.copy()
sample_agg_fn[grouping][sample_metric_key] = agg_fn
results[task_name][metric_key] = task_score
results[task_name][metric_key] = agg_fn(items)
results[task_name]["samples"] = len(items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
if False: # bootstrap_iters > 0:
if bootstrap_iters > 0:
stderr = lm_eval.api.metrics.stderr_for_metric(
metric=task.aggregation()[metric],
bootstrap_iters=min(bootstrap_iters, 1000)
bootstrap_iters=min(bootstrap_iters, 100)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
......@@ -552,33 +547,126 @@ def evaluate(
results[task_name][metric + "_stderr" + "," + key] = stderr(items)
if bool(results):
for task_or_group in results.keys():
for metric in results[task_or_group].keys():
if type(results[task_or_group][metric]) == list:
if "(sample agg)" in metric:
results[task_or_group][metric] = sample_agg_fn[
task_or_group
][metric](results[task_or_group][metric])
for group, task_list in reversed(task_hierarchy.items()):
if task_list == []:
total_size = results[group]["samples"]
else:
total_size = 0
for task in task_list:
metrics = results[task]
current_size = metrics.pop("samples")
# TODO: There should be a way for users
# to toggle between weighted and
# unweighted averaging
# For unweighted averaging, use:
# current_size = 1
all_stderr = []
for metric in [
key for key in metrics.keys() if "_stderr" not in key
]:
stderr = "_stderr,".join(metric.split(","))
stderr_score = results[task][stderr]
var_score = stderr_score**2
metric_score = results[task][metric]
all_stderr.append(stderr)
if metric in results[group]:
results[group][metric] = (
results[group][metric] * total_size
+ metric_score * current_size
) / (total_size + current_size)
# $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
results[group][stderr] = (
(total_size - 1) * results[group][stderr]
+ (current_size - 1) * var_score
) / (
total_size + current_size - 1
) + total_size * current_size / (
(total_size + current_size)
* (total_size + current_size - 1)
) * (
results[group][metric] - metric_score
) ** 2
else:
results[group][metric] = metric_score
results[group][stderr] = var_score
total_size += current_size
for stderr in all_stderr:
results[group][stderr] = np.sqrt(results[group][stderr])
results[group]["samples"] = total_size
def print_tasks(task_hierarchy, task_order, task_version, task_group_alias):
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
for group_name, task_list in task_hierarchy.items():
order = task_order[group_name]
results_agg[group_name] = results[group_name].copy()
results_agg[group_name]["tab"] = order
if (order < max(task_order.values())) and (len(task_list) > 0):
groups_agg[group_name] = results[group_name].copy()
groups_agg[group_name]["tab"] = order
if task_list != []:
for task in sorted(task_list):
if task in task_hierarchy:
_task_hierarchy = {task: task_hierarchy[task]}
else:
results[task_or_group][metric] = np.average(
results[task_or_group][metric]
)
versions[task_or_group] = "N/A"
_task_hierarchy = {task: []}
for task_name, task in task_dict.items():
if type(task) == tuple:
group_name, task = task
order = task_order[group_name]
tabbed_name = "-" * order + group_name
results_agg[tabbed_name] = results[group_name]
versions[tabbed_name] = versions[group_name]
if order == 0:
groups_agg[group_name] = results[group_name]
order = task_order[task_name]
tabbed_name = "-" * order + task_name
results_agg[tabbed_name] = results[task_name]
versions[tabbed_name] = versions[task_name]
_results_agg, _groups_agg, task_version = print_tasks(
_task_hierarchy, task_order, task_version, task_group_alias
)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
return results_agg, groups_agg, task_version
results_agg, groups_agg, versions = print_tasks(
task_hierarchy, task_order, versions, task_group_alias
)
for task in results_agg:
task_results = results_agg[task]
if "samples" in task_results:
task_results.pop("samples")
tab_string = ""
if "tab" in task_results:
tab = task_results.pop("tab")
tab_string = " " * tab + "- " if tab > 0 else ""
if task in task_group_alias:
task_alias = task_group_alias[task]
results_agg[task]["alias"] = tab_string + task_alias
else:
results_agg[task]["alias"] = tab_string + task
for group in groups_agg:
group_results = groups_agg[group]
if "samples" in group_results:
group_results.pop("samples")
tab_string = ""
if "tab" in group_results:
tab = group_results.pop("tab")
tab_string = " " * tab + "- " if tab > 0 else ""
if group in task_group_alias:
group_alias = task_group_alias[group]
groups_agg[group]["alias"] = tab_string + group_alias
else:
groups_agg[group]["alias"] = tab_string + group
results_dict = {
"results": dict(results_agg.items()),
......
from lm_eval.api.filter import FilterEnsemble
from . import selection
from . import extraction
from . import transformation
FILTER_REGISTRY = {
......@@ -9,6 +10,9 @@ FILTER_REGISTRY = {
"majority_vote": selection.MajorityVoteFilter,
"take_first_k": selection.TakeKFilter,
"remove_whitespace": extraction.WhitespaceFilter,
"lowercase": transformation.LowercaseFilter,
"uppercase": transformation.UppercaseFilter,
"map": transformation.MapFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
......
from lm_eval.api.filter import Filter
class LowercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
return [resp.lower() for resp in inst]
return [filter_set(resp) for resp in resps]
class UppercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
return [resp.upper() for resp in inst]
return [filter_set(resp) for resp in resps]
class MapFilter(Filter):
def __init__(self, mapping_dict: dict = {}, default_value=None) -> None:
"""
Initializes the MapFilter with a given mapping dictionary and default value.
Args:
- mapping_dict (dict): A dictionary containing the key-value mappings.
Default is an empty dictionary.
- default_value (Any): The value to be returned when a key is not found in the mapping_dict.
Default is None.
Example:
mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
"""
assert isinstance(
mapping_dict, dict
), "Provided mapping_dict is not a dictionary"
self.mapping_dict = mapping_dict
self.default_value = default_value
def apply(self, resps, docs):
def filter_set(inst):
return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
return [filter_set(resp) for resp in resps]
import logging
logging.basicConfig(
format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")
SPACING = " " * 47
......@@ -3,6 +3,7 @@ from . import openai_completions
from . import textsynth
from . import dummy
from . import anthropic_llms
from . import gguf
# TODO: implement __all__
......@@ -2,9 +2,11 @@ from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
import time
from lm_eval.logger import eval_logger
from lm_eval import utils
from typing import List, Any, Tuple
eval_logger = utils.eval_logger
def anthropic_completion(
client, #: anthropic.Anthropic,
......@@ -138,7 +140,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.")
def greedy_until(self, requests) -> List[str]:
def generate_until(self, requests) -> List[str]:
if not requests:
return []
......@@ -164,7 +166,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
)
res.append(response)
self.cache_hook.add_partial("greedy_until", request, response)
self.cache_hook.add_partial("generate_until", request, response)
except anthropic.APIConnectionError as e: # type: ignore # noqa: F821
eval_logger.critical(f"Server unreachable: {e.__cause__}")
break
......@@ -179,7 +181,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
# Isn't used because we override generate_until
raise NotImplementedError()
def loglikelihood(self, requests):
......
......@@ -20,7 +20,7 @@ class DummyLM(LM):
return res
def greedy_until(self, requests):
def generate_until(self, requests):
res = []
for ctx, _ in requests:
......
import requests
import logging
import time
from tqdm import tqdm
from requests.exceptions import RequestException
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
logger = logging.getLogger(__name__)
def get_result(logprobs, context_length):
is_greedy = True
offsets = logprobs["text_offset"]
tokens = logprobs["tokens"]
tokens_logprobs = logprobs["token_logprobs"]
idx = 0
while offsets[idx] < context_length:
idx += 1
continuation_logprobs = sum(tokens_logprobs[idx:-1])
for i in range(idx, len(tokens)):
token = tokens[i]
top_tokens = logprobs["top_logprobs"][i]
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
if top_token != token:
is_greedy = False
break
return continuation_logprobs, is_greedy
@register_model("gguf", "ggml")
class GGUFLM(LM):
def __init__(self, base_url=None, max_length=2048, **kwargs):
super().__init__()
self.base_url = base_url
assert self.base_url, "must pass `base_url` to use GGUF LM!"
self.logprobs = 10
self.temperature = 0.0
self.max_length = max_length
def gguf_completion(
self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs
):
for _ in range(retries):
try:
prompt = context
request = {
"prompt": prompt,
"logprobs": self.logprobs,
"temperature": self.temperature,
}
if continuation:
prompt += continuation
request.update({"prompt": prompt, "max_tokens": 1, "echo": True})
if stop is not None:
request["stop"] = stop
response = requests.post(
f"{self.base_url}/v1/completions", json=request
)
response.raise_for_status()
return response.json()
except RequestException as e:
logger.error(f"RequestException: {e}")
time.sleep(delay) # wait before retrying
else:
raise Exception(f"Failed to get a valid response after {retries} retries.")
def loglikelihood(self, requests):
if not requests:
return []
res = []
for context, continuation in tqdm([req.args for req in requests]):
response = self.gguf_completion(context=context, continuation=continuation)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
logprobs = choice.get("logprobs")
if (
logprobs
and "token_logprobs" in logprobs
and logprobs["token_logprobs"]
):
logprob, is_greedy = get_result(logprobs, len(context))
res.append((logprob, is_greedy))
else:
logger.warning(
"Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
)
else:
logger.error(
f"Invalid response for loglikelihood. Response: {response}"
)
assert False
return res
def generate_until(self, requests):
if not requests:
return []
res = []
for request in tqdm([req.args for req in requests]):
inp = request[0]
request_args = request[1]
until = request_args.get("until", ["</s>"])
response = self.gguf_completion(context=inp, stop=until)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
if "text" in choice:
generated_text = choice["text"].strip()
res.append(generated_text)
else:
logger.error(
f"Invalid response for greedy_until. Response: {response}"
)
res.append(None) # Add default value in case of error
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
return res
def loglikelihood_rolling(self, requests):
raise NotImplementedError(
"loglikelihood_rolling not yet supported for GGUF models"
)
......@@ -16,9 +16,10 @@ from transformers.models.auto.modeling_auto import (
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.logger import eval_logger
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
eval_logger = utils.eval_logger
def _get_accelerate_args(
device_map_option: Optional[str] = "auto",
......@@ -413,12 +414,13 @@ class HFLM(LM):
utils.clear_torch_cache()
return batch_size
def tok_encode(self, string: str, left_truncate_len=None):
def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None):
""" """
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
add_special_tokens = True
if add_special_tokens is None:
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
add_special_tokens = True
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
......@@ -502,7 +504,7 @@ class HFLM(LM):
self.tokenizer, stop, 1, context.shape[0]
)
return self.model.generate(
context,
input_ids=context,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
......@@ -533,8 +535,12 @@ class HFLM(LM):
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
context_enc = self.tok_encode(context, add_special_tokens=False)
# whole_enc = self.tok_encode(context + continuation)
# context_enc = self.tok_encode(context, add_special_tokens=False)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
......@@ -614,6 +620,23 @@ class HFLM(LM):
return loglikelihoods
def _batch_scheduler(self, pos, n_reordered_requests):
sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
if sched in self.batch_sizes:
return self.batch_sizes[sched]
if (len(self.batch_sizes) > 1) and (
self.batch_sizes[sched - 1] == self.max_batch_size
):
# if previous batch size is already maximal, skip recomputation
self.batch_sizes[sched] = self.max_batch_size
return self.batch_sizes[sched]
print(
f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
)
self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
print(f"Determined largest batch size: {self.batch_sizes[sched]}")
return self.batch_sizes[sched]
def _loglikelihood_tokens(
self, requests, disable_tqdm: bool = False, override_bs=None
):
......@@ -637,38 +660,22 @@ class HFLM(LM):
# automatic (variable) batch size detection for vectorization
# pull longest context sample from request
def _batch_scheduler(pos):
sched = pos // int(n_reordered_requests / self.batch_schedule)
if sched in self.batch_sizes:
return self.batch_sizes[sched]
if (len(self.batch_sizes) > 1) and (
self.batch_sizes[sched - 1] == self.max_batch_size
):
# if previous batch size is already maximal, skip recomputation
self.batch_sizes[sched] = self.max_batch_size
return self.batch_sizes[sched]
print(
f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
)
self.batch_sizes[sched] = self._detect_batch_size(
re_ord.get_reordered(), pos
)
print(f"Determined largest batch size: {self.batch_sizes[sched]}")
return self.batch_sizes[sched]
for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
chunks = utils.chunks(
re_ord.get_reordered(),
n=self.batch_size
if self.batch_size != "auto"
else override_bs
if override_bs is not None
else 0,
fn=_batch_scheduler
fn=self._batch_scheduler
if self.batch_size == "auto"
and n_reordered_requests > 0
and not override_bs
else None,
):
)
pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
for chunk in chunks:
inps = []
cont_toks_list = []
inplens = []
......@@ -805,10 +812,13 @@ class HFLM(LM):
res.append(answer)
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
pbar.update(1)
pbar.close()
return re_ord.get_original(res)
def greedy_until(self, requests):
def generate_until(self, requests):
res = defaultdict(list)
re_ords = {}
......@@ -831,13 +841,26 @@ class HFLM(LM):
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
if self.batch_size == "auto":
# using rolling window with maximum context
print("Passed argument batch_size = auto. Detecting largest batch size")
batch_size = self._detect_batch_size()
print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size
# for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items():
for chunk in utils.chunks(
chunks = utils.chunks(
re_ord.get_reordered(),
self.batch_size,
):
n=self.batch_size
if self.batch_size != "auto"
else adaptive_batch_size
if adaptive_batch_size is not None
else 0,
fn=self._batch_scheduler
if self.batch_size == "auto" and not adaptive_batch_size
else None,
)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
......@@ -864,8 +887,6 @@ class HFLM(LM):
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
primary_until = [until[0]]
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
......@@ -891,7 +912,7 @@ class HFLM(LM):
cont = self._model_generate(
context=context_enc,
attention_mask=attn_masks,
stop=primary_until,
stop=until,
**kwargs,
)
......@@ -913,7 +934,7 @@ class HFLM(LM):
res[key].append(s)
self.cache_hook.add_partial(
"greedy_until", (context, gen_kwargs), s
"generate_until", (context, gen_kwargs), s
)
pbar.update(1)
# reorder this group of results back to original unsorted form
......
......@@ -203,7 +203,7 @@ class OpenaiCompletionsLM(LM):
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res)
def greedy_until(self, requests) -> List[str]:
def generate_until(self, requests) -> List[str]:
if not requests:
return []
res = []
......@@ -260,7 +260,7 @@ class OpenaiCompletionsLM(LM):
# partial caching
self.cache_hook.add_partial(
"greedy_until", (context, {"until": until_}), s
"generate_until", (context, {"until": until_}), s
)
res.append(s)
......@@ -271,7 +271,7 @@ class OpenaiCompletionsLM(LM):
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
# Isn't used because we override generate_until
raise NotImplementedError()
def loglikelihood_rolling(self, requests) -> List[float]:
......
......@@ -58,7 +58,7 @@ class TextSynthLM(LM):
@property
def eot_token_id(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
# Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
raise NotImplementedError()
@property
......@@ -72,20 +72,20 @@ class TextSynthLM(LM):
@property
def batch_size(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
# Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
raise NotImplementedError()
@property
def device(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
# Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
raise NotImplementedError()
def tok_encode(self, string: str):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
# Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
raise NotImplementedError()
def tok_decode(self, tokens):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
# Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
raise NotImplementedError()
def loglikelihood(self, requests):
......@@ -122,7 +122,7 @@ class TextSynthLM(LM):
"input tokenization support from TextSynth."
)
def greedy_until(self, requests):
def generate_until(self, requests):
if not requests:
return []
......@@ -146,7 +146,7 @@ class TextSynthLM(LM):
s = resp["text"]
res.append(s)
self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
self.cache_hook.add_partial("generate_until", (inp, request_args), s)
else:
logger.error(
f"The following response does not contain generated `text`. "
......@@ -160,5 +160,5 @@ class TextSynthLM(LM):
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
# Isn't used because we override generate_until
raise NotImplementedError()
import os
import ast
from typing import Dict
from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.utils import eval_logger
# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
......@@ -47,6 +48,14 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None
raise ValueError(
f"{prompt_name} not in prompt list {prompts.all_template_names}"
)
elif ".yaml" in category_name:
import yaml
with open(category_name, "rb") as file:
prompt_yaml_file = yaml.full_load(file)
prompt_string = prompt_yaml_file["prompts"][prompt_name]
return PromptString(prompt_string)
else:
try:
return PROMPT_REGISTRY[category_name][prompt_name]
......@@ -57,21 +66,62 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None
)
def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwargs):
def load_prompt_list(
use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
):
from promptsource.templates import DatasetTemplates
category_name, prompt_name = use_prompt.split(":")
if subset_name is None:
prompts = DatasetTemplates(dataset_name=dataset_name)
else:
prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)
if category_name == "promptsource":
from promptsource.templates import DatasetTemplates
if subset_name is None:
prompts = DatasetTemplates(dataset_name=dataset_name)
else:
prompts = DatasetTemplates(
dataset_name=dataset_name, subset_name=subset_name
)
prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
elif ".yaml" in category_name:
import yaml
category_name, *prompt_name = use_prompt.split(":")
if yaml_path is not None:
category_name = os.path.realpath(os.path.join(yaml_path, category_name))
with open(category_name, "rb") as file:
prompt_yaml_file = yaml.full_load(file)
prompt_list = utils.pattern_match(
prompt_name, prompt_yaml_file["prompts"].keys()
)
# category_name, *prompt_name = use_prompt.split(":")
# TODO allow to multiple prompt naming
# if len(prompt_name) > 1:
# prompt_list = []
# for prompt in prompt_name:
# prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
# else:
prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
# prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
return [":".join([category_name, prompt]) for prompt in prompt_list]
class PromptString:
def __init__(self, prompt_string):
self.prompt_string = prompt_string
def apply(self, doc):
doc_to_text = self.prompt_string["doc_to_text"]
doc_to_target = self.prompt_string["doc_to_target"]
# TODO need a way to process doc_to_choice
if "doc_to_choice" in self.prompt_string:
raise "Not yet implemented to accept doc_to_choice"
text_string = utils.apply_template(doc_to_text, doc)
target_string = utils.apply_template(doc_to_target, doc)
return [text_string, target_string]
......@@ -59,6 +59,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] MGSM
- [ ] SCROLLS
- [x] Babi
- [x] Belebele
# Novel Tasks
Tasks added in the revamped harness that were not previously available. Again, a strikethrough denotes checking performed *against the original task's implementation or published results introducing the task*.
......
......@@ -4,7 +4,6 @@ from typing import List, Union, Dict
from lm_eval import utils
from lm_eval import prompts
from lm_eval.logger import eval_logger
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.registry import (
register_task,
......@@ -14,6 +13,21 @@ from lm_eval.api.registry import (
ALL_TASKS,
)
import logging
# import python tasks
from .squadv2.task import SQuAD2
from .scrolls.task import (
QuALITY,
NarrativeQA,
ContractNLI,
GovReport,
SummScreenFD,
QMSum,
)
eval_logger = utils.eval_logger
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
......@@ -27,7 +41,9 @@ def register_configurable_task(config: Dict[str, str]) -> int:
register_task(task_name)(SubClass)
if "group" in config:
if type(config["group"]) == str:
if config["group"] == config["task"]:
raise ValueError("task and group name cannot be the same")
elif type(config["group"]) == str:
group_name = [config["group"]]
else:
group_name = config["group"]
......@@ -38,18 +54,20 @@ def register_configurable_task(config: Dict[str, str]) -> int:
return 0
def register_configurable_group(config: Dict[str, str]) -> int:
def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int:
group = config["group"]
all_task_list = config["task"]
config_list = [task for task in all_task_list if type(task) != str]
task_list = [task for task in all_task_list if type(task) == str]
for task_config in config_list:
task_config = utils.load_yaml_config(yaml_path, task_config)
var_configs = check_prompt_config(
{
**task_config,
**{"group": group},
}
},
yaml_path=os.path.dirname(yaml_path),
)
for config in var_configs:
register_configurable_task(config)
......@@ -66,13 +84,16 @@ def register_configurable_group(config: Dict[str, str]) -> int:
return 0
def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
def check_prompt_config(
config: Dict[str, str], yaml_path: str = None
) -> List[Dict[str, str]]:
all_configs = []
if "use_prompt" in config:
prompt_list = prompts.load_prompt_list(
use_prompt=config["use_prompt"],
dataset_name=config["dataset_path"],
subset_name=config["dataset_name"] if "dataset_name" in config else None,
yaml_path=yaml_path,
)
for idx, prompt_variation in enumerate(prompt_list):
all_configs.append(
......@@ -85,11 +106,13 @@ def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
config["task"]
if "task" in config
else get_task_name_from_config(config),
prompt_variation,
prompt_variation.split("/")[-1]
if ".yaml" in prompt_variation
else prompt_variation,
]
)
},
**{"output_type": "greedy_until"},
**{"output_type": "generate_until"},
}
)
else:
......@@ -104,41 +127,64 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
return "{dataset_path}".format(**task_config)
def include_task_folder(task_dir: str, register_task=True) -> None:
def include_task_folder(task_dir: str, register_task: bool = True) -> None:
"""
Calling this function
"""
for root, subdirs, file_list in os.walk(task_dir):
if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
try:
config = utils.load_yaml_config(yaml_path)
# if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
try:
config = utils.load_yaml_config(yaml_path)
if "task" not in config:
continue
all_configs = check_prompt_config(
config, yaml_path=os.path.dirname(yaml_path)
)
for config in all_configs:
if register_task:
all_configs = check_prompt_config(config)
for config in all_configs:
if type(config["task"]) == str:
register_configurable_task(config)
else:
# If a `task` in config is a list,
# that means it's a benchmark
if type(config["task"]) == list:
register_configurable_group(config)
except Exception as error:
eval_logger.warning(
"Failed to load config in\n"
f" {yaml_path}\n"
" Config will not be added to registry\n"
f" Error: {error}"
)
register_configurable_group(config, yaml_path)
# Log this silently and show it only when
# the user defines the appropriate verbosity.
except ModuleNotFoundError as e:
eval_logger.debug(
f"{yaml_path}: {e}. Config will not be added to registry."
)
except Exception as error:
import traceback
eval_logger.debug(
"Failed to load config in\n"
f" {yaml_path}\n"
" Config will not be added to registry\n"
f" Error: {error}\n"
f" Traceback: {traceback.format_exc()}"
)
return 0
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_task_folder(task_dir)
# Register Benchmarks after all tasks have been added
include_task_folder(task_dir, register_task=False)
def include_path(task_dir):
include_task_folder(task_dir)
# Register Benchmarks after all tasks have been added
include_task_folder(task_dir, register_task=False)
return 0
def initialize_tasks(verbosity="INFO"):
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_path(task_dir)
def get_task(task_name, config):
......@@ -166,7 +212,6 @@ def get_task_name_from_object(task_object):
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
config = {**kwargs}
task_name_from_registry_dict = {}
......@@ -178,7 +223,6 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
for task_element in task_name_list:
if isinstance(task_element, str):
if task_element in GROUP_REGISTRY:
group_name = task_element
for task_name in GROUP_REGISTRY[task_element]:
......@@ -216,7 +260,6 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
}
elif isinstance(task_element, Task):
task_name_from_object_dict = {
**task_name_from_object_dict,
get_task_name_from_object(task_element): task_element,
......
# ASDiv
### Paper
Title: `ASDiv: A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers`
Abstract: https://arxiv.org/abs/2106.15772
ASDiv (Academia Sinica Diverse MWP Dataset) is a diverse (in terms of both language
patterns and problem types) English math word problem (MWP) corpus for evaluating
the capability of various MWP solvers. Existing MWP corpora for studying AI progress
remain limited either in language usage patterns or in problem types. We thus present
a new English MWP corpus with 2,305 MWPs that cover more text patterns and most problem
types taught in elementary school. Each MWP is annotated with its problem type and grade
level (for indicating the level of difficulty).
NOTE: We currently ignore formulas for answer generation.
Homepage: https://github.com/chaochun/nlu-asdiv-dataset
### Citation
```
@misc{miao2021diverse,
title={A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers},
author={Shen-Yun Miao and Chao-Chun Liang and Keh-Yih Su},
year={2021},
eprint={2106.15772},
archivePrefix={arXiv},
primaryClass={cs.AI}
}
```
### Groups and Tasks
#### Groups
* Not part of a group yet.
#### Tasks
* `asdiv`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
task: babi
dataset_path: Muennighoff/babi
dataset_name: null
output_type: greedy_until
output_type: generate_until
training_split: train
validation_split: valid
test_split: test
......
# BigBenchHard
## Paper
Title: `Challenging BIG-Bench Tasks and Whether Chain-of-Thought Can Solve Them`
Abstract: https://arxiv.org/abs/2210.09261
A suite of 23 challenging BIG-Bench tasks which we call BIG-Bench Hard (BBH).
These are the task for which prior language model evaluations did not outperform
the average human-rater.
Homepage: https://github.com/suzgunmirac/BIG-Bench-Hard
## Citation
```
@article{suzgun2022challenging,
title={Challenging BIG-Bench Tasks and Whether Chain-of-Thought Can Solve Them},
author={Suzgun, Mirac and Scales, Nathan and Sch{\"a}rli, Nathanael and Gehrmann, Sebastian and Tay, Yi and Chung, Hyung Won and Chowdhery, Aakanksha and Le, Quoc V and Chi, Ed H and Zhou, Denny and and Wei, Jason},
journal={arXiv preprint arXiv:2210.09261},
year={2022}
}
```
### Groups and Tasks
#### Groups
- `bbh_flan_zeroshot`
- `bbh_flan_fewshot`
- `bbh_flan_cot_fewshot`
- `bbh_flan_cot_zeroshot`
#### Tasks
- ...
### Checklist
- [x] Is in Eval-harness v1.0 ?
- [ ] Has been checked for regression from v1.0?
- [ ] Has been checked for equivalence with original paper methodology?
- [ ] "Main" checked variant clearly denoted?
### Variant Wishlist
- [ ] Variant with Calculator (see https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py for example implementation)
- [ ] Using Verifiers
- [ ] Majority voting "without CoT"
"""
Take in a YAML, and output all other splits with this YAML
"""
import os
import re
import yaml
import requests
import argparse
import datasets
from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--base_yaml_path", required=True)
parser.add_argument("--save_prefix_path", default="flan_zeroshot")
parser.add_argument("--cot", default=False)
parser.add_argument("--fewshot", default=False)
parser.add_argument("--task_prefix", default="")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
base_yaml_name = os.path.split(args.base_yaml_path)[-1]
with open(args.base_yaml_path) as f:
base_yaml = yaml.full_load(f)
base_doc_to_text = "Q: {{input}}\nA:"
answer_regex = re.compile("(?<=answer is )(.*)(?=.)")
dataset_path = "lukaemon/bbh"
for task in tqdm(datasets.get_dataset_infos(dataset_path).keys()):
resp = requests.get(
f"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/cot-prompts/{task}.txt"
).content.decode("utf-8")
prompt = resp.split("\n-----\n")[-1]
description, *few_shot = prompt.split("\n\nQ:")
prefix_doc_to_text = ""
if args.fewshot:
if args.cot:
prefix_doc_to_text = " ".join(few_shot)
else:
for shot in few_shot:
shot = "Q:" + shot
try:
answer = answer_regex.search(shot)[0]
except Exception:
print("task", task)
print(shot)
example = shot.split("Let's think step by step.")[0]
prefix_doc_to_text += f"{example}{answer}\n\n"
doc_to_text = prefix_doc_to_text + base_doc_to_text
if args.cot:
doc_to_text = doc_to_text + " Let's think step by step.\n"
yaml_dict = {
"include": base_yaml_name,
"task": f"bbh_{args.task_prefix}_{task}",
"dataset_name": task,
"description": description + "\n\n",
"doc_to_text": doc_to_text,
}
file_save_path = args.save_prefix_path + f"/{task}.yaml"
eval_logger.info(f"Saving yaml for subset {task} to {file_save_path}")
with open(file_save_path, "w") as yaml_file:
yaml.dump(
yaml_dict,
yaml_file,
width=float("inf"),
allow_unicode=True,
default_style='"',
)
group: bbh_flan_cot_fewshot
dataset_path: lukaemon/bbh
output_type: generate_until
test_split: test
doc_to_target: "{{target}}"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
generation_kwargs:
until:
- "</s>"
do_sample: false
temperature: 0.0
filter_list:
- name: "get-answer"
filter:
- function: "regex"
regex_pattern: "(?<=the answer is )(.*)(?=.)"
- function: "take_first"
"dataset_name": "boolean_expressions"
"description": "Evaluate the result of a random Boolean expression.\n\n"
"doc_to_text": " not ( ( not not True ) ) is\nA: Let's think step by step.\nRemember that (i) expressions inside brackets are always evaluated first and that (ii) the order of operations from highest priority to lowest priority is \"not\", \"and\", \"or\", respectively.\nWe first simplify this expression \"Z\" as follows: \"Z = not ( ( not not True ) ) = not ( ( A ) )\" where \"A = not not True\".\nLet's evaluate A: A = not not True = not (not True) = not False = True.\nPlugging in A, we get: Z = not ( ( A ) ) = not ( ( True ) ) = not True = False. So the answer is False. True and False and not True and True is\nA: Let's think step by step.\nRemember that (i) expressions inside brackets are always evaluated first and that (ii) the order of operations from highest priority to lowest priority is \"not\", \"and\", \"or\", respectively.\nWe first simplify this expression \"Z\" as follows: \"Z = True and False and not True and True = A and B\" where \"A = True and False\" and \"B = not True and True\".\nLet's evaluate A: A = True and False = False.\nLet's evaluate B: B = not True and True = not (True and True) = not (True) = False.\nPlugging in A and B, we get: Z = A and B = False and False = False. So the answer is False. not not ( not ( False ) ) is\nA: Let's think step by step.\nRemember that (i) expressions inside brackets are always evaluated first and that (ii) the order of operations from highest priority to lowest priority is \"not\", \"and\", \"or\", respectively.\nWe first simplify this expression \"Z\" as follows: \"Z = not not ( not ( False ) ) = not not ( A )\" where \"A = not ( False )\".\nLet's evaluate A: A = not ( False ) = not False = True.\nPlugging in A, we get: Z = not not ( A ) = not not (True) = not not False = True. So the answer is True.Q: {{input}}\nA: Let's think step by step.\n"
"include": "_flan_cot_fewshot_template_yaml"
"task": "bbh_flan_cot_fewshot_boolean_expressions"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment